/*
* Copyright(c) 2024 Intel Corporation
* SPDX - License - Identifier: BSD - 2 - Clause - Patent
*/

#include "gtest/gtest.h"
#include "random.h"
#include "Quant.h"
#include "Quant_sse4_1.h"
#include "Quant_avx2.h"

/* Capacity (in coefficients) of every buffer used by this test. */
#define QUANT_MAX_SIZE 128
/* Coefficient group size passed to the kernel; one gclis entry per group. */
#define GROUP_SIZE     4
/* Number of gclis entries needed to cover QUANT_MAX_SIZE coefficients
 * (ceiling division). Fully parenthesized so the macro expands correctly
 * inside larger expressions such as 2 * GCLI_SIZE or n % GCLI_SIZE. */
#define GCLI_SIZE      ((QUANT_MAX_SIZE + GROUP_SIZE - 1) / GROUP_SIZE)

/* Values of the test parameter; RAND_SIZE is the number of values. */
enum RAND_TYPE { RAND_CUSTOM = 0, RAND_ONE, RAND_ZERO, RAND_FULL, RAND_SIZE };

/*
 * One quantization test case: the arguments handed to the kernel under test
 * and the buffer contents it is expected to produce.
 */
struct quantInput {
    /* buf_len is passed to the kernel as `size`, gtli as `gtli`. */
    uint32_t buf_len, gtli;
    /* Passed to the kernel as `quant_type`. */
    QUANT_TYPE dq_type;
    /* Expected contents of the whole QUANT_MAX_SIZE buffer after the call;
     * the test compares every entry, not only the first buf_len. */
    uint16_t expected_output[QUANT_MAX_SIZE];
};

/*
 * Input coefficients shared by every test case. The test copies them into a
 * scratch buffer before each kernel call and never writes to this table, so it
 * is const. Bit 15 appears to act as a sign bit (sign-magnitude) — TODO
 * confirm against Quant.h.
 */
static const uint16_t quant_input[QUANT_MAX_SIZE]{
    0x8001, 0x2,    0x1,    0,      0x16,   0x80f3, 0xf8,   0x80e0, 0x80de, 0x8655, 0x86a1, 0x82b8, 0x805c, 0x113,  0x81fd,
    0x8009, 0,      0,      0,      0,      0x8bbf, 0x328,  0x694,  0x9970, 0x2226, 0x9aec, 0xd06,  0xac59, 0x10,   0x801c,
    0x800c, 0x800a, 0x8345, 0x83b4, 0x1d,   0x35f,  0x8001, 0x8001, 0x8001, 0x1,    0x6,    0x806c, 0xb6,   0x44,   0x6,
    0x5,    0x6,    0x3,    0x10f8, 0x19b4, 0x8f20, 0x1d27, 0x23dc, 0x824e, 0x37cd, 0x9b77, 0x8045, 0x8048, 0x8044, 0x8041,
    0x3,    0x8003, 0,      0x1,    0x8,    0x9,    0xb,    0x8,    0x97a5, 0x1ddf, 0x17c3, 0x1917, 0xbd,   0x19a,  0xc4,
    0x800c, 0x81fb, 0x372,  0x85e2, 0x392,  0xb2be, 0x83df, 0xe6,   0x8bf5, 0x991c, 0x3655, 0x3354, 0x2a1d, 0x8003, 0x2,
    0x8001, 0x1,    0x9a8a, 0x2875, 0x3a2d, 0xa720, 0x8031, 0x802c, 0x9,    0x9,    0x820c, 0x366,  0x81e5, 0x276,  0x8080,
    0xc1,   0x80f6, 0x802c, 0,      0,      0,      0,      0x8b38, 0x7c2,  0x645,  0xc1e,  0,      0,      0,      0,
    0x2,    0x8007, 0xb,    0x800c, 0,      0,      0,      0};

/* One GCLI value per GROUP_SIZE coefficients of quant_input. The test copies
 * it for the kernel and later compares the copy against this table to verify
 * the kernel left it untouched; it is only read, hence const. */
static const uint8_t gclis_input[GCLI_SIZE]{2, 8,  12, 9,  0,  13, 14, 5,  10, 1,  8, 3, 13, 14, 7, 2,
                                            4, 13, 9,  11, 14, 14, 2,  14, 6,  10, 8, 0, 12, 0,  4, 0};

/*
 * Reference cases for the quantization kernels. For each case the test runs
 * the kernel on a fresh copy of quant_input (with GROUP_SIZE and a fresh copy
 * of gclis_input) and compares the entire QUANT_MAX_SIZE buffer against
 * expected_output.
 *
 * Properties visible in the data (not derived from the kernel source):
 *  - with gtli == 0 (tests 1 and 3) expected_output equals quant_input;
 *  - entries at index >= buf_len are identical to quant_input, i.e. the
 *    kernel is expected to leave them untouched;
 *  - bit 15 appears to be a sign bit (e.g. 0x8003 -> 0x0000 in test 2) —
 *    TODO confirm against Quant.h.
 *
 * NOTE(review): this table is not const because set_corectness_data() keeps a
 * non-const pointer into expected_output.
 */
static quantInput quant_test_vector[5] = {
    /* Test 1: deadzone, gtli 0, first 100 coefficients */
    {100,                 /*buf_len*/
     0,                   /*gtli*/
     QUANT_TYPE_DEADZONE, /*dq_type*/
     {
         0x8001, 0x02,   0x01,   0000,   0x16,   0x80f3, 0xf8,   0x80e0, 0x80de, 0x8655, 0x86a1, 0x82b8, 0x805c, 0x113,  0x81fd,
         0x8009, 0000,   0000,   0000,   0000,   0x8bbf, 0x328,  0x694,  0x9970, 0x2226, 0x9aec, 0xd06,  0xac59, 0x10,   0x801c,
         0x800c, 0x800a, 0x8345, 0x83b4, 0x1d,   0x35f,  0x8001, 0x8001, 0x8001, 0x01,   0x06,   0x806c, 0xb6,   0x44,   0x06,
         0x05,   0x06,   0x03,   0x10f8, 0x19b4, 0x8f20, 0x1d27, 0x23dc, 0x824e, 0x37cd, 0x9b77, 0x8045, 0x8048, 0x8044, 0x8041,
         0x03,   0x8003, 0000,   0x01,   0x08,   0x09,   0x0b,   0x08,   0x97a5, 0x1ddf, 0x17c3, 0x1917, 0xbd,   0x19a,  0xc4,
         0x800c, 0x81fb, 0x372,  0x85e2, 0x392,  0xb2be, 0x83df, 0xe6,   0x8bf5, 0x991c, 0x3655, 0x3354, 0x2a1d, 0x8003, 0x02,
         0x8001, 0x01,   0x9a8a, 0x2875, 0x3a2d, 0xa720, 0x8031, 0x802c, 0x09,   0x09,   0x820c, 0x366,  0x81e5, 0x276,  0x8080,
         0xc1,   0x80f6, 0x802c, 0000,   0000,   0000,   0000,   0x8b38, 0x7c2,  0x645,  0xc1e,  0000,   0000,   0000,   0000,
         0x02,   0x8007, 0x0b,   0x800c, 0000,   0000,   0000,   0000,
     }},
    /* Test 2: deadzone, gtli 3, first 90 coefficients */
    {90,                  /*buf_len*/
     3,                   /*gtli*/
     QUANT_TYPE_DEADZONE, /*dq_type*/
     {
         0000,   0000,   0000,   0000,   0x10,   0x80f0, 0xf8,   0x80e0, 0x80d8, 0x8650, 0x86a0, 0x82b8, 0x8058, 0x110,  0x81f8,
         0x8008, 0000,   0000,   0000,   0000,   0x8bb8, 0x328,  0x690,  0x9970, 0x2220, 0x9ae8, 0xd00,  0xac58, 0x10,   0x8018,
         0x8008, 0x8008, 0x8340, 0x83b0, 0x18,   0x358,  0000,   0000,   0000,   0000,   0000,   0x8068, 0xb0,   0x40,   0000,
         0000,   0000,   0000,   0x10f8, 0x19b0, 0x8f20, 0x1d20, 0x23d8, 0x8248, 0x37c8, 0x9b70, 0x8040, 0x8048, 0x8040, 0x8040,
         0000,   0000,   0000,   0000,   0x08,   0x08,   0x08,   0x08,   0x97a0, 0x1dd8, 0x17c0, 0x1910, 0xb8,   0x198,  0xc0,
         0x8008, 0x81f8, 0x370,  0x85e0, 0x390,  0xb2b8, 0x83d8, 0xe0,   0x8bf0, 0x9918, 0x3650, 0x3350, 0x2a18, 0000,   0000,
         0x8001, 0x01,   0x9a8a, 0x2875, 0x3a2d, 0xa720, 0x8031, 0x802c, 0x09,   0x09,   0x820c, 0x366,  0x81e5, 0x276,  0x8080,
         0xc1,   0x80f6, 0x802c, 0000,   0000,   0000,   0000,   0x8b38, 0x7c2,  0x645,  0xc1e,  0000,   0000,   0000,   0000,
         0x02,   0x8007, 0x0b,   0x800c, 0000,   0000,   0000,   0000,
     }},
    /* Test 3: uniform, gtli 0, first 80 coefficients */
    {80,                 /*buf_len*/
     0,                  /*gtli*/
     QUANT_TYPE_UNIFORM, /*dq_type*/
     {
         0x8001, 0x02,   0x01,   0000,   0x16,   0x80f3, 0xf8,   0x80e0, 0x80de, 0x8655, 0x86a1, 0x82b8, 0x805c, 0x113,  0x81fd,
         0x8009, 0000,   0000,   0000,   0000,   0x8bbf, 0x328,  0x694,  0x9970, 0x2226, 0x9aec, 0xd06,  0xac59, 0x10,   0x801c,
         0x800c, 0x800a, 0x8345, 0x83b4, 0x1d,   0x35f,  0x8001, 0x8001, 0x8001, 0x01,   0x06,   0x806c, 0xb6,   0x44,   0x06,
         0x05,   0x06,   0x03,   0x10f8, 0x19b4, 0x8f20, 0x1d27, 0x23dc, 0x824e, 0x37cd, 0x9b77, 0x8045, 0x8048, 0x8044, 0x8041,
         0x03,   0x8003, 0000,   0x01,   0x08,   0x09,   0x0b,   0x08,   0x97a5, 0x1ddf, 0x17c3, 0x1917, 0xbd,   0x19a,  0xc4,
         0x800c, 0x81fb, 0x372,  0x85e2, 0x392,  0xb2be, 0x83df, 0xe6,   0x8bf5, 0x991c, 0x3655, 0x3354, 0x2a1d, 0x8003, 0x02,
         0x8001, 0x01,   0x9a8a, 0x2875, 0x3a2d, 0xa720, 0x8031, 0x802c, 0x09,   0x09,   0x820c, 0x366,  0x81e5, 0x276,  0x8080,
         0xc1,   0x80f6, 0x802c, 0000,   0000,   0000,   0000,   0x8b38, 0x7c2,  0x645,  0xc1e,  0000,   0000,   0000,   0000,
         0x02,   0x8007, 0x0b,   0x800c, 0000,   0000,   0000,   0000,
     }},
    /* Test 4: uniform, gtli 4, first 70 coefficients */
    {70,                 /*buf_len*/
     4,                  /*gtli*/
     QUANT_TYPE_UNIFORM, /*dq_type*/
     {
         0000,   0000,   0000,   0000,   0x10,   0x80f0, 0xf0,   0x80e0, 0x80e0, 0x8650, 0x86a0, 0x82b0, 0x8060, 0x110,  0x81f0,
         0x8010, 0000,   0000,   0000,   0000,   0x8bc0, 0x320,  0x690,  0x9970, 0x2220, 0x9af0, 0xd00,  0xac50, 0x10,   0x8010,
         0x8010, 0000,   0x8340, 0x83b0, 0x20,   0x360,  0000,   0000,   0000,   0000,   0000,   0x8070, 0xb0,   0x40,   0000,
         0000,   0000,   0000,   0x10f0, 0x19b0, 0x8f20, 0x1d20, 0x23d0, 0x8250, 0x37c0, 0x9b70, 0x8040, 0x8040, 0x8040, 0x8040,
         0000,   0000,   0000,   0000,   0000,   0000,   0000,   0000,   0x97a0, 0x1dd0, 0x17c3, 0x1917, 0xbd,   0x19a,  0xc4,
         0x800c, 0x81fb, 0x372,  0x85e2, 0x392,  0xb2be, 0x83df, 0xe6,   0x8bf5, 0x991c, 0x3655, 0x3354, 0x2a1d, 0x8003, 0x02,
         0x8001, 0x01,   0x9a8a, 0x2875, 0x3a2d, 0xa720, 0x8031, 0x802c, 0x09,   0x09,   0x820c, 0x366,  0x81e5, 0x276,  0x8080,
         0xc1,   0x80f6, 0x802c, 0000,   0000,   0000,   0000,   0x8b38, 0x7c2,  0x645,  0xc1e,  0000,   0000,   0000,   0000,
         0x02,   0x8007, 0x0b,   0x800c, 0000,   0000,   0000,   0000,
     }},
    /* Test 5: uniform, gtli 6, first 60 coefficients */
    {60,                 /*buf_len*/
     6,                  /*gtli*/
     QUANT_TYPE_UNIFORM, /*dq_type*/
     {
         0000,   0000,   0000,   0000,   0000,   0x80c0, 0xc0,   0x80c0, 0x80c0, 0x8640, 0x8680, 0x82c0, 0x8040, 0x100,  0x81c0,
         0000,   0000,   0000,   0000,   0000,   0x8bc0, 0x340,  0x680,  0x9940, 0x2200, 0x9ac0, 0xd00,  0xac40, 0000,   0000,
         0000,   0000,   0x8340, 0x8380, 0000,   0x340,  0000,   0000,   0000,   0000,   0000,   0x8040, 0x80,   0x40,   0000,
         0000,   0000,   0000,   0x1100, 0x1980, 0x8f00, 0x1d00, 0x23c0, 0x8240, 0x37c0, 0x9b80, 0x8040, 0x8040, 0x8040, 0x8040,
         0x03,   0x8003, 0000,   0x01,   0x08,   0x09,   0x0b,   0x08,   0x97a5, 0x1ddf, 0x17c3, 0x1917, 0xbd,   0x19a,  0xc4,
         0x800c, 0x81fb, 0x372,  0x85e2, 0x392,  0xb2be, 0x83df, 0xe6,   0x8bf5, 0x991c, 0x3655, 0x3354, 0x2a1d, 0x8003, 0x02,
         0x8001, 0x01,   0x9a8a, 0x2875, 0x3a2d, 0xa720, 0x8031, 0x802c, 0x09,   0x09,   0x820c, 0x366,  0x81e5, 0x276,  0x8080,
         0xc1,   0x80f6, 0x802c, 0000,   0000,   0000,   0000,   0x8b38, 0x7c2,  0x645,  0xc1e,  0000,   0000,   0000,   0000,
         0x02,   0x8007, 0x0b,   0x800c, 0000,   0000,   0000,   0000,
     }},
};

/*
 * Parameterized fixture that runs a quantization kernel against every case
 * in quant_test_vector. For each case it checks two things: the coefficient
 * buffer matches expected_output, and the gclis array is left unchanged.
 *
 * Ownership: buf_mod and gclis are heap scratch buffers owned by the fixture.
 * buf_ref only borrows quant_test_vector[].expected_output and is never freed.
 *
 * NOTE(review): the test parameter is stored in rand_type and `rand` is
 * created, but run_correctness() uses neither. Every parameter value therefore
 * runs the same deterministic vectors.
 */
class QuantTest : public ::testing::TestWithParam<int> {
  protected:
    static svt_jxs_test_tool::SVTRandom* rand;
    RAND_TYPE rand_type = RAND_CUSTOM;
    uint32_t buffer_size = 0;
    uint32_t gtli = 0;
    QUANT_TYPE dq_type = QUANT_TYPE_DEADZONE;
    /* Borrowed pointer into quant_test_vector[].expected_output (not owned). */
    uint16_t* buf_ref = NULL;
    /* Owned copy of quant_input that the kernel modifies in place. */
    uint16_t* buf_mod = NULL;
    /* Owned copy of gclis_input handed to the kernel. */
    uint8_t* gclis = NULL;

  protected:
    static void SetUpTestCase() {
        rand = new svt_jxs_test_tool::SVTRandom(32, false);
    } //Prepare data for all tests
    static void TearDownTestCase() {
        delete rand;
        rand = NULL;
    } //Release data for all tests

    /* Allocates the scratch buffers; records a fatal failure if either
     * allocation fails (TearDown releases whichever one succeeded). */
    void set_up_corectness() {
        buffer_size = 0;
        rand_type = (enum RAND_TYPE)GetParam();
        //buf_ref is pointing to quant_test_vector[].expected_output
        buf_ref = NULL;
        buf_mod = (uint16_t*)malloc(sizeof(quant_input));
        gclis = (uint8_t*)malloc(sizeof(gclis_input));
        ASSERT_NE(buf_mod, nullptr);
        ASSERT_NE(gclis, nullptr);
    }

    /* Runs after every test, including ones where an ASSERT_* returned early
     * from run_correctness(). Only the owned buffers are released: buf_ref
     * points into a static table and must never be passed to free(). */
    void TearDown() override {
        tear_down_corectness();
    }

    /* Frees the owned scratch buffers and clears all pointers. Safe to call
     * more than once (free(NULL) is a no-op). */
    void tear_down_corectness() {
        free(buf_mod);
        buf_mod = NULL;
        free(gclis);
        gclis = NULL;
        buf_ref = NULL;
    }

    /* Loads one test case's kernel arguments and expected output. */
    void set_corectness_data(quantInput* input) {
        gtli = input->gtli;
        buf_ref = input->expected_output;
        buffer_size = input->buf_len;
        dq_type = input->dq_type;
    }

    /* Runs quant_fn_ptr on fresh copies of quant_input / gclis_input for each
     * case. Compares the whole buffer with expected_output and checks that
     * gclis was not modified. */
    void run_correctness(void (*quant_fn_ptr)(uint16_t* buf, uint32_t size, uint8_t* gclis, uint32_t group_size, uint8_t gtli,
                                              QUANT_TYPE quant_type)) {
        // A fatal failure inside the helper only returns from the helper, so
        // propagate it here instead of using NULL buffers.
        ASSERT_NO_FATAL_FAILURE(set_up_corectness());
        const uint32_t test_size = sizeof(quant_test_vector) / sizeof(quant_test_vector[0]);

        for (uint32_t test_idx = 0; test_idx < test_size; test_idx++) {
            set_corectness_data(&quant_test_vector[test_idx]);

            memcpy(gclis, gclis_input, sizeof(gclis_input));
            memcpy(buf_mod, quant_input, sizeof(quant_input));

            quant_fn_ptr(buf_mod, buffer_size, gclis, GROUP_SIZE, (uint8_t)gtli, dq_type);

            if (memcmp(buf_ref, buf_mod, sizeof(quant_input))) {
                printf("quantization calculation error for test %u\n", test_idx);
                ASSERT_EQ(memcmp(buf_ref, buf_mod, sizeof(quant_input)), 0);
            }

            if (memcmp(gclis_input, gclis, sizeof(gclis_input))) {
                printf("quantization calculation error for test %u\n gclis differ!\n", test_idx);
                ASSERT_EQ(memcmp(gclis_input, gclis, sizeof(gclis_input)), 0);
            }
        }
        tear_down_corectness();
    }
};

// Out-of-class definition of the shared generator; created in SetUpTestCase().
svt_jxs_test_tool::SVTRandom* QuantTest::rand = nullptr;

// Each implementation of the quantization kernel is checked against the same
// reference vectors.
// NOTE(review): no CPU-feature check is visible for the SSE4.1/AVX2 variants;
// on hardware without those extensions they would presumably fault — confirm
// whether the test runner filters them.
TEST_P(QuantTest, corectness_test_c) {
    run_correctness(quantization_c);
}

TEST_P(QuantTest, corectness_test_sse41) {
    run_correctness(quantization_sse4_1);
}

TEST_P(QuantTest, corectness_test_avx2) {
    run_correctness(quantization_avx2);
}

// One instantiation per RAND_TYPE value: parameters 0 .. RAND_SIZE - 1.
INSTANTIATE_TEST_SUITE_P(Quant, QuantTest, ::testing::Range(0, static_cast<int>(RAND_SIZE)));
