/* BEGIN_HEADER */
#include "mbedtls/gcm.h"

/* Use the multipart interface to process the input in two parts and check
 * that both the output and the tag match the expected values.
 *
 * The additional data is passed to mbedtls_gcm_update_ad() as the first
 * n1_add bytes of add followed by the remaining bytes; the input is passed
 * to mbedtls_gcm_update() as the first n1 bytes followed by the remaining
 * bytes. Either part may be empty.
 *
 * The context must have been set up with the key.
 *
 * Return 1 if all checks passed, 0 if any of them failed. */
static int check_multipart( mbedtls_gcm_context *ctx,
                            int mode,
                            const data_t *iv,
                            const data_t *add,
                            const data_t *input,
                            const data_t *expected_output,
                            const data_t *tag,
                            size_t n1,
                            size_t n1_add)
{
    int ok = 0;
    uint8_t *output = NULL;
    /* Lengths of the second parts. These wrap around if n1 or n1_add is too
     * large; the sanity checks below are there to reject that case before
     * the values are used. */
    size_t n2 = input->len - n1;
    size_t n2_add = add->len - n1_add;
    size_t olen;

    /* Sanity checks on the test data */
    TEST_ASSERT( n1 <= input->len );
    TEST_ASSERT( n1_add <= add->len );
    TEST_EQUAL( input->len, expected_output->len );

    /* Start the operation, then feed the additional data in two parts. */
    TEST_EQUAL( 0, mbedtls_gcm_starts( ctx, mode,
                                         iv->x, iv->len ) );
    TEST_EQUAL( 0, mbedtls_gcm_update_ad( ctx, add->x, n1_add ) );
    TEST_EQUAL( 0, mbedtls_gcm_update_ad( ctx, add->x + n1_add, n2_add ) );

    /* Allocate a tight buffer for each update call. This way, if the function
     * tries to write beyond the advertised required buffer size, this will
     * count as an overflow for memory sanitizers and static checkers. */
    ASSERT_ALLOC( output, n1 );
    /* Poison olen so that a call which leaves it untouched is detected by
     * the length comparison that follows. */
    olen = 0xdeadbeef;
    TEST_EQUAL( 0, mbedtls_gcm_update( ctx, input->x, n1, output, n1, &olen ) );
    TEST_EQUAL( n1, olen );
    ASSERT_COMPARE( output, olen, expected_output->x, n1 );
    mbedtls_free( output );
    /* Reset the pointer: the cleanup code frees it again, and the next
     * allocation starts from a NULL pointer. */
    output = NULL;

    /* Second part of the input, again with an exactly-sized buffer. */
    ASSERT_ALLOC( output, n2 );
    olen = 0xdeadbeef;
    TEST_EQUAL( 0, mbedtls_gcm_update( ctx, input->x + n1, n2, output, n2, &olen ) );
    TEST_EQUAL( n2, olen );
    ASSERT_COMPARE( output, olen, expected_output->x + n1, n2 );
    mbedtls_free( output );
    output = NULL;

    /* Finish the operation and compare the generated tag with the expected
     * one. NOTE(review): the NULL/0 pair is presumably the (unused) data
     * output buffer of mbedtls_gcm_finish() - confirm against gcm.h. */
    ASSERT_ALLOC( output, tag->len );
    TEST_EQUAL( 0, mbedtls_gcm_finish( ctx, NULL, 0, output, tag->len ) );
    ASSERT_COMPARE( output, tag->len, tag->x, tag->len );
    mbedtls_free( output );
    output = NULL;

    /* Only reached when none of the checks above failed. */
    ok = 1;
exit:
    /* Releases whichever buffer was live when a check failed; a no-op on
     * the success path because output is NULL there. */
    mbedtls_free( output );
    return( ok );
}

/* Run one multipart GCM operation whose additional data is empty and check
 * the output and the tag against the expected values.
 *
 * mbedtls_gcm_update_ad() is called ad_update_count times (possibly zero
 * times) with a NULL, zero-length buffer; the whole input then goes through
 * a single mbedtls_gcm_update() call.
 *
 * The context must have been set up with the key. Failures are reported
 * through the test macros; nothing is returned. */
static void check_cipher_with_empty_ad( mbedtls_gcm_context *ctx,
                                        int mode,
                                        const data_t *iv,
                                        const data_t *input,
                                        const data_t *expected_output,
                                        const data_t *tag,
                                        size_t ad_update_count)
{
    uint8_t *output = NULL;
    /* Poisoned up front: the update call has to overwrite this for the
     * length check to pass. */
    size_t olen = 0xdeadbeef;
    size_t calls_left;

    /* The test data must describe input and output of equal length. */
    TEST_EQUAL( input->len, expected_output->len );

    TEST_EQUAL( 0, mbedtls_gcm_starts( ctx, mode, iv->x, iv->len ) );

    /* Feed "no additional data" the requested number of times. */
    for( calls_left = ad_update_count; calls_left != 0; calls_left-- )
    {
        TEST_EQUAL( 0, mbedtls_gcm_update_ad( ctx, NULL, 0 ) );
    }

    /* The output buffer is exactly as large as the data, so that any write
     * past the advertised size is visible to memory sanitizers and static
     * checkers as an overflow. */
    ASSERT_ALLOC( output, input->len );
    TEST_EQUAL( 0, mbedtls_gcm_update( ctx, input->x, input->len, output, input->len, &olen ) );
    TEST_EQUAL( input->len, olen );
    ASSERT_COMPARE( output, olen, expected_output->x, input->len );

    /* Drop the data buffer and start over with a NULL pointer for the tag. */
    mbedtls_free( output );
    output = NULL;

    ASSERT_ALLOC( output, tag->len );
    TEST_EQUAL( 0, mbedtls_gcm_finish( ctx, NULL, 0, output, tag->len ) );
    ASSERT_COMPARE( output, tag->len, tag->x, tag->len );

exit:
    mbedtls_free( output );
}

/* Run one multipart GCM operation that has additional data but no
 * plaintext/ciphertext, and check the resulting tag.
 *
 * All of add is passed in a single mbedtls_gcm_update_ad() call, after
 * which mbedtls_gcm_update() is called cipher_update_count times (possibly
 * zero times) with NULL, zero-length buffers.
 *
 * The context must have been set up with the key. Failures are reported
 * through the test macros; nothing is returned. */
static void check_empty_cipher_with_ad( mbedtls_gcm_context *ctx,
                                       int mode,
                                       const data_t *iv,
                                       const data_t *add,
                                       const data_t *tag,
                                       size_t cipher_update_count)
{
    uint8_t *output_tag = NULL;
    size_t calls_left;
    size_t olen;

    TEST_EQUAL( 0, mbedtls_gcm_starts( ctx, mode, iv->x, iv->len ) );
    TEST_EQUAL( 0, mbedtls_gcm_update_ad( ctx, add->x, add->len ) );

    for( calls_left = cipher_update_count; calls_left != 0; calls_left-- )
    {
        /* Re-poison before every call so that each empty update must
         * itself report an output length of zero. */
        olen = 0xdeadbeef;
        TEST_EQUAL( 0, mbedtls_gcm_update( ctx, NULL, 0, NULL, 0, &olen ) );
        TEST_EQUAL( 0, olen );
    }

    /* Tag buffer sized exactly to the expected tag. */
    ASSERT_ALLOC( output_tag, tag->len );
    TEST_EQUAL( 0, mbedtls_gcm_finish( ctx, NULL, 0, output_tag, tag->len ) );
    ASSERT_COMPARE( output_tag, tag->len, tag->x, tag->len );

exit:
    mbedtls_free( output_tag );
}

/* Run a multipart GCM operation with neither additional data nor
 * plaintext/ciphertext: start the operation, finish it immediately, and
 * check the resulting tag against the expected one.
 *
 * The context must have been set up with the key. Failures are reported
 * through the test macros; nothing is returned. */
static void check_no_cipher_no_ad( mbedtls_gcm_context *ctx,
                                   int mode,
                                   const data_t *iv,
                                   const data_t *tag )
{
    uint8_t *output = NULL;

    TEST_EQUAL( 0, mbedtls_gcm_starts( ctx, mode,
                                       iv->x, iv->len ) );
    /* Tag buffer sized exactly to the expected tag. */
    ASSERT_ALLOC( output, tag->len );
    TEST_EQUAL( 0, mbedtls_gcm_finish( ctx, NULL, 0, output, tag->len ) );
    ASSERT_COMPARE( output, tag->len, tag->x, tag->len );

exit:
    mbedtls_free( output );
}

/* END_HEADER */

/* BEGIN_DEPENDENCIES
 * depends_on:MBEDTLS_GCM_C
 * END_DEPENDENCIES
 */

/* BEGIN_CASE */
void gcm_bad_parameters( int cipher_id, int direction,
                         data_t *key_str, data_t *src_str,
                         data_t *iv_str, data_t *add_str,
                         int tag_len_bits, int gcm_result )
{
    /* Set the key (which must succeed), then make a single
     * mbedtls_gcm_crypt_and_tag() call and require that it return exactly
     * gcm_result. The tag length is given in bits and converted to bytes. */
    mbedtls_gcm_context ctx;
    unsigned char output[128] = { 0 };
    unsigned char tag_output[16] = { 0 };
    size_t tag_len = tag_len_bits / 8;

    mbedtls_gcm_init( &ctx );

    TEST_ASSERT( mbedtls_gcm_setkey( &ctx, cipher_id, key_str->x, key_str->len * 8 ) == 0 );
    TEST_ASSERT( mbedtls_gcm_crypt_and_tag( &ctx, direction, src_str->len, iv_str->x, iv_str->len,
                 add_str->x, add_str->len, src_str->x, output, tag_len, tag_output ) == gcm_result );

exit:
    mbedtls_gcm_free( &ctx );
}
/* END_CASE */

/* BEGIN_CASE */
void gcm_encrypt_and_tag( int cipher_id, data_t * key_str,
                          data_t * src_str, data_t * iv_str,
                          data_t * add_str, data_t * dst,
                          int tag_len_bits, data_t * tag,
                          int init_result )
{
    /* Check that setting the key returns init_result. When that is 0,
     * encrypt in one shot and compare ciphertext and tag with dst and tag,
     * then repeat the encryption through the multipart interface for every
     * way of splitting the input and the additional data in two. */
    mbedtls_gcm_context ctx;
    unsigned char output[128] = { 0 };
    unsigned char tag_output[16] = { 0 };
    size_t tag_len = tag_len_bits / 8;
    size_t input_split;
    size_t ad_split;

    mbedtls_gcm_init( &ctx );

    TEST_ASSERT( mbedtls_gcm_setkey( &ctx, cipher_id, key_str->x, key_str->len * 8 ) == init_result );
    if( init_result != 0 )
    {
        /* The key was expected to be rejected: nothing more to test. */
        goto exit;
    }

    /* One-shot encryption. */
    TEST_ASSERT( mbedtls_gcm_crypt_and_tag( &ctx, MBEDTLS_GCM_ENCRYPT,
                                            src_str->len, iv_str->x,
                                            iv_str->len, add_str->x,
                                            add_str->len, src_str->x,
                                            output, tag_len,
                                            tag_output ) == 0 );

    ASSERT_COMPARE( output, src_str->len, dst->x, dst->len );
    ASSERT_COMPARE( tag_output, tag_len, tag->x, tag->len );

    /* Multipart encryption, for every pair of split points. */
    for( input_split = 0; input_split <= src_str->len; input_split++ )
    {
        for( ad_split = 0; ad_split <= add_str->len; ad_split++ )
        {
            /* Encode both split points in the step number so that a
             * failure report identifies the combination. */
            mbedtls_test_set_step( input_split * 10000 + ad_split );
            if( check_multipart( &ctx, MBEDTLS_GCM_ENCRYPT,
                                 iv_str, add_str, src_str,
                                 dst, tag,
                                 input_split, ad_split ) == 0 )
            {
                goto exit;
            }
        }
    }

exit:
    mbedtls_gcm_free( &ctx );
}
/* END_CASE */

/* BEGIN_CASE */
void gcm_decrypt_and_verify( int cipher_id, data_t * key_str,
                             data_t * src_str, data_t * iv_str,
                             data_t * add_str, int tag_len_bits,
                             data_t * tag_str, char * result,
                             data_t * pt_result, int init_result )
{
    /* Check that setting the key returns init_result. When that is 0,
     * decrypt and authenticate in one shot. If result is the string "FAIL"
     * the call must report an authentication failure; otherwise it must
     * succeed and produce pt_result, and the same decryption is repeated
     * through the multipart interface for every way of splitting the input
     * and the additional data in two. */
    mbedtls_gcm_context ctx;
    unsigned char output[128] = { 0 };
    size_t tag_len = tag_len_bits / 8;
    size_t input_split;
    size_t ad_split;
    int ret;

    mbedtls_gcm_init( &ctx );

    TEST_ASSERT( mbedtls_gcm_setkey( &ctx, cipher_id, key_str->x, key_str->len * 8 ) == init_result );
    if( init_result != 0 )
    {
        /* The key was expected to be rejected: nothing more to test. */
        goto exit;
    }

    ret = mbedtls_gcm_auth_decrypt( &ctx, src_str->len,
                                    iv_str->x, iv_str->len,
                                    add_str->x, add_str->len,
                                    tag_str->x, tag_len,
                                    src_str->x, output );

    if( strcmp( "FAIL", result ) == 0 )
    {
        /* Negative test: only the error code is checked. */
        TEST_ASSERT( ret == MBEDTLS_ERR_GCM_AUTH_FAILED );
        goto exit;
    }

    TEST_ASSERT( ret == 0 );
    ASSERT_COMPARE( output, src_str->len, pt_result->x, pt_result->len );

    /* Multipart decryption, for every pair of split points. */
    for( input_split = 0; input_split <= src_str->len; input_split++ )
    {
        for( ad_split = 0; ad_split <= add_str->len; ad_split++ )
        {
            /* Encode both split points in the step number so that a
             * failure report identifies the combination. */
            mbedtls_test_set_step( input_split * 10000 + ad_split );
            if( check_multipart( &ctx, MBEDTLS_GCM_DECRYPT,
                                 iv_str, add_str, src_str,
                                 pt_result, tag_str,
                                 input_split, ad_split ) == 0 )
            {
                goto exit;
            }
        }
    }

exit:
    mbedtls_gcm_free( &ctx );
}
/* END_CASE */

/* BEGIN_CASE */
void gcm_decrypt_and_verify_empty_cipher( int cipher_id,
                                          data_t * key_str,
                                          data_t * iv_str,
                                          data_t * add_str,
                                          data_t * tag_str,
                                          int cipher_update_calls )
{
    /* Multipart decryption with additional data but no ciphertext: set the
     * key, then let check_empty_cipher_with_ad() run the operation with
     * cipher_update_calls empty update calls and verify the tag against
     * tag_str. */
    mbedtls_gcm_context ctx;

    mbedtls_gcm_init( &ctx );

    TEST_ASSERT( mbedtls_gcm_setkey( &ctx, cipher_id, key_str->x, key_str->len * 8 ) == 0 );
    check_empty_cipher_with_ad( &ctx, MBEDTLS_GCM_DECRYPT,
                                iv_str, add_str, tag_str,
                                cipher_update_calls );

    /* The test macros jump to `exit` on failure; the label has to sit
     * before the cleanup so that the context is released even when setting
     * the key fails. */
exit:
    mbedtls_gcm_free( &ctx );
}
/* END_CASE */

/* BEGIN_CASE */
void gcm_decrypt_and_verify_empty_ad( int cipher_id,
                                      data_t * key_str,
                                      data_t * iv_str,
                                      data_t * src_str,
                                      data_t * tag_str,
                                      data_t * pt_result,
                                      int ad_update_calls )
{
    /* Multipart decryption with ciphertext but no additional data: set the
     * key, then let check_cipher_with_empty_ad() run the operation with
     * ad_update_calls empty update_ad calls and verify the output against
     * pt_result and the tag against tag_str. */
    mbedtls_gcm_context ctx;

    mbedtls_gcm_init( &ctx );

    TEST_ASSERT( mbedtls_gcm_setkey( &ctx, cipher_id, key_str->x, key_str->len * 8 ) == 0 );
    check_cipher_with_empty_ad( &ctx, MBEDTLS_GCM_DECRYPT,
                                iv_str, src_str, pt_result, tag_str,
                                ad_update_calls );

    /* The test macros jump to `exit` on failure; the label has to sit
     * before the cleanup so that the context is released even when setting
     * the key fails. */
exit:
    mbedtls_gcm_free( &ctx );
}
/* END_CASE */

/* BEGIN_CASE */
void gcm_decrypt_and_verify_no_ad_no_cipher( int cipher_id,
                                             data_t * key_str,
                                             data_t * iv_str,
                                             data_t * tag_str )
{
    /* Multipart decryption with neither additional data nor ciphertext:
     * set the key, then let check_no_cipher_no_ad() start and finish the
     * operation and verify the tag against tag_str. */
    mbedtls_gcm_context ctx;

    mbedtls_gcm_init( &ctx );

    TEST_ASSERT( mbedtls_gcm_setkey( &ctx, cipher_id, key_str->x, key_str->len * 8 ) == 0 );
    check_no_cipher_no_ad( &ctx, MBEDTLS_GCM_DECRYPT,
                           iv_str, tag_str );

    /* The test macros jump to `exit` on failure; the label has to sit
     * before the cleanup so that the context is released even when setting
     * the key fails. */
exit:
    mbedtls_gcm_free( &ctx );
}
/* END_CASE */

/* BEGIN_CASE */
void gcm_encrypt_and_tag_empty_cipher( int cipher_id,
                                       data_t * key_str,
                                       data_t * iv_str,
                                       data_t * add_str,
                                       data_t * tag_str,
                                       int cipher_update_calls )
{
    /* Multipart encryption with additional data but no plaintext: set the
     * key, then let check_empty_cipher_with_ad() run the operation with
     * cipher_update_calls empty update calls and verify the tag against
     * tag_str. */
    mbedtls_gcm_context ctx;

    mbedtls_gcm_init( &ctx );

    /* Setting the key must succeed before the operation can be run. */
    TEST_ASSERT( mbedtls_gcm_setkey( &ctx, cipher_id, key_str->x, key_str->len * 8 ) == 0 );
    check_empty_cipher_with_ad( &ctx, MBEDTLS_GCM_ENCRYPT,
                                iv_str, add_str, tag_str,
                                cipher_update_calls );

    /* Reached both on normal completion and when a test macro fails, so
     * the context is always released. */
exit:
    mbedtls_gcm_free( &ctx );
}

/* BEGIN_CASE */
void gcm_encrypt_and_tag_empty_ad( int cipher_id,
                                   data_t * key_str,
                                   data_t * iv_str,
                                   data_t * src_str,
                                   data_t * dst,
                                   data_t * tag_str,
                                   int ad_update_calls )
{
    /* Multipart encryption with plaintext but no additional data: set the
     * key, then let check_cipher_with_empty_ad() run the operation with
     * ad_update_calls empty update_ad calls and verify the output against
     * dst and the tag against tag_str. */
    mbedtls_gcm_context ctx;

    mbedtls_gcm_init( &ctx );

    /* Setting the key must succeed before the operation can be run. */
    TEST_ASSERT( mbedtls_gcm_setkey( &ctx, cipher_id, key_str->x, key_str->len * 8 ) == 0 );
    check_cipher_with_empty_ad( &ctx, MBEDTLS_GCM_ENCRYPT,
                                iv_str, src_str, dst, tag_str,
                                ad_update_calls );

    /* Reached both on normal completion and when a test macro fails, so
     * the context is always released. */
exit:
    mbedtls_gcm_free( &ctx );
}
/* END_CASE */

/* BEGIN_CASE */
void gcm_encrypt_and_verify_no_ad_no_cipher( int cipher_id,
                                             data_t * key_str,
                                             data_t * iv_str,
                                             data_t * tag_str )
{
    /* Multipart encryption with neither additional data nor plaintext:
     * set the key, then let check_no_cipher_no_ad() start and finish the
     * operation and verify the tag against tag_str. */
    mbedtls_gcm_context ctx;

    mbedtls_gcm_init( &ctx );

    TEST_ASSERT( mbedtls_gcm_setkey( &ctx, cipher_id, key_str->x, key_str->len * 8 ) == 0 );
    check_no_cipher_no_ad( &ctx, MBEDTLS_GCM_ENCRYPT,
                           iv_str, tag_str );

    /* The test macros jump to `exit` on failure; the label has to sit
     * before the cleanup so that the context is released even when setting
     * the key fails. */
exit:
    mbedtls_gcm_free( &ctx );
}
/* END_CASE */

/* BEGIN_CASE depends_on:NOT_DEFINED */
void gcm_invalid_param( )
{
    /* Check that mbedtls_gcm_setkey() rejects an invalid key size with
     * MBEDTLS_ERR_GCM_BAD_INPUT.
     * NOTE(review): the BEGIN_CASE line makes this case depend on
     * NOT_DEFINED, so it is presumably never built - confirm whether that
     * is still intended. */
    mbedtls_gcm_context ctx;
    unsigned char valid_buffer[] = { 0x01, 0x02, 0x03, 0x04, 0x05, 0x06 };
    mbedtls_cipher_id_t valid_cipher = MBEDTLS_CIPHER_ID_AES;
    /* A key size of 1 bit, which the call below is expected to refuse. */
    int invalid_bitlen = 1;

    mbedtls_gcm_init( &ctx );

    /* mbedtls_gcm_setkey */
    TEST_EQUAL(
        MBEDTLS_ERR_GCM_BAD_INPUT,
        mbedtls_gcm_setkey( &ctx, valid_cipher, valid_buffer, invalid_bitlen ) );

exit:
    mbedtls_gcm_free( &ctx );
}
/* END_CASE */

/* BEGIN_CASE depends_on:MBEDTLS_SELF_TEST */
void gcm_selftest(  )
{
    /* Run the library's built-in GCM self-test and require that it return
     * 0. NOTE(review): the argument 1 presumably enables verbose output -
     * confirm against gcm.h. */
    TEST_ASSERT( mbedtls_gcm_self_test( 1 ) == 0 );
}
/* END_CASE */