/*
 * Copyright 2015-2018 Yubico AB
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ctype.h>
#include <string.h>

#include <openssl/bio.h>
#include <openssl/evp.h>
#include <openssl/pem.h>
#include <openssl/x509.h>

#include "openssl-compat.h"
#include "util.h"
#include "insecure_memzero.h"
#include "hash.h"

/*
 * Serialize bn as a big-endian, zero-padded integer of exactly element_len
 * bytes into in_ptr.
 *
 * Returns true only when BN_bn2binpad() reports that exactly element_len
 * bytes were produced; any other return value is reported as false.
 */
bool set_component(unsigned char *in_ptr, const BIGNUM *bn, int element_len) {
  const int produced = BN_bn2binpad(bn, in_ptr, element_len);

  if (produced != element_len) {
    return false;
  }

  return true;
}

/*
 * DER prefixes that precede the raw hash value inside a DigestInfo structure.
 * Each one is SEQUENCE { SEQUENCE { OID, NULL }, OCTET STRING (hash length) }
 * up to, but not including, the hash bytes themselves. parse_NID() matches
 * these against incoming data to pick the digest.
 */

/* SHA-1: OID 1.3.14.3.2.26, followed by a 0x14 (20) byte OCTET STRING. */
static unsigned const char sha1oid[] = {0x30, 0x21, 0x30, 0x09, 0x06,
                                        0x05, 0x2B, 0x0E, 0x03, 0x02,
                                        0x1A, 0x05, 0x00, 0x04, 0x14};

/* SHA-256: OID 2.16.840.1.101.3.4.2.1, 0x20 (32) byte OCTET STRING. */
static unsigned const char sha256oid[] = {0x30, 0x31, 0x30, 0x0D, 0x06,
                                          0x09, 0x60, 0x86, 0x48, 0x01,
                                          0x65, 0x03, 0x04, 0x02, 0x01,
                                          0x05, 0x00, 0x04, 0x20};

/* SHA-384: OID 2.16.840.1.101.3.4.2.2, 0x30 (48) byte OCTET STRING. */
static unsigned const char sha384oid[] = {0x30, 0x41, 0x30, 0x0D, 0x06,
                                          0x09, 0x60, 0x86, 0x48, 0x01,
                                          0x65, 0x03, 0x04, 0x02, 0x02,
                                          0x05, 0x00, 0x04, 0x30};

/* SHA-512: OID 2.16.840.1.101.3.4.2.3, 0x40 (64) byte OCTET STRING. */
static unsigned const char sha512oid[] = {0x30, 0x51, 0x30, 0x0D, 0x06,
                                          0x09, 0x60, 0x86, 0x48, 0x01,
                                          0x65, 0x03, 0x04, 0x02, 0x03,
                                          0x05, 0x00, 0x04, 0x40};

/*
 * PEM armor lines, each including its trailing newline. As string literals
 * the arrays also carry a terminating NUL, so the text length is
 * sizeof(...) - 1 (28 and 26 bytes for the private header and trailer).
 */
static unsigned const char PEM_private_header[] =
  "-----BEGIN PRIVATE KEY-----\n";
static unsigned const char PEM_private_trailer[] =
  "-----END PRIVATE KEY-----\n";
static unsigned const char PEM_public_header[] = "-----BEGIN PUBLIC KEY-----\n";
static unsigned const char PEM_public_trailer[] = "-----END PUBLIC KEY-----\n";

/*
 * 16-byte DER prefix of an Ed25519 private key container: SEQUENCE,
 * INTEGER 0, AlgorithmIdentifier with OID 1.3.101.112, then an OCTET STRING
 * wrapping a 0x20 (32) byte OCTET STRING. The 32 key bytes follow directly.
 */
static unsigned const char ed25519private_oid[] = {0x30, 0x2e, 0x02, 0x01,
                                                   0x00, 0x30, 0x05, 0x06,
                                                   0x03, 0x2b, 0x65, 0x70,
                                                   0x04, 0x22, 0x04, 0x20};

/*
 * 12-byte DER prefix of an Ed25519 public key container: SEQUENCE,
 * AlgorithmIdentifier with OID 1.3.101.112, then a 0x21 byte BIT STRING with
 * zero unused bits. The public key bytes follow directly.
 */
static unsigned const char ed25519public_oid[] = {0x30, 0x2a, 0x30, 0x05,
                                                  0x06, 0x03, 0x2b, 0x65,
                                                  0x70, 0x03, 0x21, 0x00};

/*
 * Extract the raw 32-byte Ed25519 private key from a PEM "PRIVATE KEY" blob.
 *
 * in / in_len   PEM text. It must start with PEM_private_header and carry
 *               PEM_private_trailer (minus its final byte) starting 26 bytes
 *               before the end, i.e. exactly one byte follows the trailer
 *               dashes and that byte is not inspected.
 * out           receives the 32 key bytes on success; the caller must provide
 *               at least 32 bytes (the incoming *out_len is not consulted).
 * out_len       set to 32 on success, untouched on failure.
 *
 * Returns true only when the base64 payload decodes to exactly 48 bytes that
 * begin with ed25519private_oid.
 *
 * The scratch buffer holding the decoded structure is wiped on every path
 * that decoded into it: this function is also tried on non-Ed25519
 * "PRIVATE KEY" PEMs, whose decoded bytes are private key material too.
 */
bool read_ed25519_key(uint8_t *in, size_t in_len, uint8_t *out,
                      size_t *out_len) {

  uint8_t decoded[128];
  size_t decoded_len = sizeof(decoded);
  bool ok = false;
  int ret;
  BIO *b64 = NULL;
  BIO *bio = NULL;

  /* 28 = header text length, 26 = trailer text length (both with newline). */
  if (in_len < (28 + 26)) {
    return false;
  }
  if (memcmp(in, PEM_private_header, 28) != 0 ||
      memcmp(in + in_len - 26, PEM_private_trailer, 25) != 0) {
    return false;
  }

  b64 = BIO_new(BIO_f_base64());
  if (b64 == NULL) {
    return false;
  }
  bio = BIO_new(BIO_s_mem());
  if (bio == NULL) {
    BIO_free_all(b64);
    return false;
  }
  BIO_set_flags(b64, BIO_FLAGS_BASE64_NO_NL);
  BIO_push(b64, bio);

  /*
   * Feed everything after the header up to and including the first byte of
   * the trailer (the span ends at in_len - 25, the trailer starts at
   * in_len - 26). NOTE(review): presumably the leading '-' is included on
   * purpose as an end-of-data marker for the decoder - confirm before
   * changing this length.
   */
  if (BIO_write(bio, in + 28, in_len - 28 - 25) <= 0) {
    BIO_free_all(b64);
    return false;
  }
  if (BIO_flush(bio) != 1) {
    BIO_free_all(b64);
    return false;
  }
  ret = BIO_read(b64, decoded, decoded_len);

  BIO_free_all(b64);

  /* 48 = 16-byte DER prefix + 32-byte key. */
  if (ret == 48 &&
      memcmp(decoded, ed25519private_oid, sizeof(ed25519private_oid)) == 0) {
    memcpy(out, decoded + 16, 32);
    *out_len = 32;
    ok = true;
  }

  /* Wipe the whole scratch buffer, on failure as well as on success. */
  insecure_memzero(decoded, sizeof(decoded));

  return ok;
}

/*
 * Parse a PEM private key and export its components as raw, zero-padded,
 * big-endian byte strings.
 *
 * buf / len      PEM text.
 * algo           receives the detected algorithm.
 * bytes          output buffer.
 * bytes_len      on success, set to the number of bytes written to bytes.
 * internal_repr  false: export only the private part; true: export the
 *                extended layout described below.
 *
 * The input is first offered to read_ed25519_key(); if that fails it is
 * parsed with PEM_read_bio_PrivateKey() (no password) and must be RSA or EC.
 *
 * Layouts written to bytes:
 *   Ed25519, !internal_repr : the 32-byte key.
 *   Ed25519,  internal_repr : SHA-512 of the key (as produced by
 *                             hash_bytes()), with the first 32 bytes clamped
 *                             and byte-reversed, and the 32-byte public key
 *                             stored at offset 64. Unavailable (returns
 *                             false) with OpenSSL < 1.1.1 or LibreSSL.
 *   RSA, !internal_repr     : p || q, size/2 bytes each.
 *   RSA,  internal_repr     : p || q || dmp1 || dmq1 || iqmp || n.
 *   EC,  !internal_repr     : private scalar, field-size bytes.
 *   EC,   internal_repr     : scalar || x || y.
 *
 * Only RSA keys of 2048/3072/4096 bits with public exponent 0x010001 and the
 * EC curves listed below are accepted.
 *
 * NOTE: the RSA and EC paths write into bytes without consulting the incoming
 * *bytes_len, so the caller must supply a buffer big enough for the largest
 * layout.
 */
bool read_private_key(uint8_t *buf, size_t len, yh_algorithm *algo,
                      uint8_t *bytes, size_t *bytes_len, bool internal_repr) {

  size_t out_len = *bytes_len;
  if (read_ed25519_key(buf, len, bytes, &out_len) == true) {
    *algo = YH_ALGO_EC_ED25519;

    if (internal_repr == true) {
#if (OPENSSL_VERSION_NUMBER < 0x10101000L) || defined(LIBRESSL_VERSION_NUMBER)
      return false;
#else
      EVP_PKEY *pkey =
        EVP_PKEY_new_raw_private_key(EVP_PKEY_ED25519, NULL, bytes, out_len);
      if (pkey == NULL) {
        return false;
      }

      /* The public key is parked past the 64 bytes the hash will occupy. */
      size_t public_key_len = 0xff;
      if (EVP_PKEY_get_raw_public_key(pkey, bytes + 64, &public_key_len) != 1 ||
          public_key_len != 32) {
        EVP_PKEY_free(pkey);
        return false;
      }

      EVP_PKEY_free(pkey);

      /* Hash the key in place; bytes_len receives the digest length. */
      if (hash_bytes(bytes, out_len, _SHA512, bytes, bytes_len) == false) {
        return false;
      }

      /* Clamp the first 32 bytes of the digest. */
      bytes[0] &= 248;
      bytes[31] &= 127;
      bytes[31] |= 64;

      /* Reverse the byte order of those 32 bytes. */
      for (uint8_t i = 0; i < 16; i++) {
        uint8_t tmp = bytes[i];
        bytes[i] = bytes[31 - i];
        bytes[31 - i] = tmp;
      }

      *bytes_len += public_key_len;
#endif
    } else {
      /*
       * Report the length of the raw key, as the RSA and EC branches do.
       * Previously *bytes_len was left at whatever the caller passed in.
       */
      *bytes_len = out_len;
    }

    return true;
  }

  EVP_PKEY *private_key;

  BIO *bio = BIO_new(BIO_s_mem());
  if (bio == NULL) {
    return false;
  }

  if (BIO_write(bio, buf, len) <= 0) {
    BIO_free_all(bio);
    return false;
  }

  private_key = PEM_read_bio_PrivateKey(bio, NULL, NULL, /*password*/ NULL);
  BIO_free_all(bio);
  if (private_key == NULL) {
    return false;
  }

  bool ret = false;

  RSA *rsa = NULL;

  BIGNUM *x = NULL;
  BIGNUM *y = NULL;
  EC_KEY *ec_private = NULL;

  switch (EVP_PKEY_base_id(private_key)) {
    case EVP_PKEY_RSA: {
      rsa = EVP_PKEY_get1_RSA(private_key);
      if (rsa == NULL) {
        goto cleanup;
      }
      unsigned char e[4];
      int size = RSA_size(rsa);
      const BIGNUM *bn_n, *bn_e, *bn_p, *bn_q;

      RSA_get0_key(rsa, &bn_n, &bn_e, NULL);
      RSA_get0_factors(rsa, &bn_p, &bn_q);

      /* Only the public exponent 0x010001 is accepted. */
      if (set_component(e, bn_e, 3) == false ||
          !(e[0] == 0x01 && e[1] == 0x00 && e[2] == 0x01)) {
        goto cleanup;
      }

      /* size is the modulus length in bytes. */
      if (size == 256) {
        *algo = YH_ALGO_RSA_2048;
      } else if (size == 384) {
        *algo = YH_ALGO_RSA_3072;
      } else if (size == 512) {
        *algo = YH_ALGO_RSA_4096;
      } else {
        goto cleanup;
      }

      if (set_component(bytes, bn_p, size / 2) == false) {
        goto cleanup;
      }

      if (set_component(bytes + size / 2, bn_q, size / 2) == false) {
        goto cleanup;
      }

      if (internal_repr == true) {
        const BIGNUM *dmp1, *dmq1, *iqmp;
        uint8_t *ptr = bytes + size;

        RSA_get0_crt_params(rsa, &dmp1, &dmq1, &iqmp);
        if (set_component(ptr, dmp1, size / 2) == false) {
          goto cleanup;
        }
        ptr += size / 2;

        if (set_component(ptr, dmq1, size / 2) == false) {
          goto cleanup;
        }
        ptr += size / 2;

        if (set_component(ptr, iqmp, size / 2) == false) {
          goto cleanup;
        }
        ptr += size / 2;

        if (set_component(ptr, bn_n, size) == false) {
          goto cleanup;
        }

        /* Five half-size components plus the full-size modulus. */
        *bytes_len = (size / 2) * 7;
      } else {
        *bytes_len = size;
      }
    } break;

    case EVP_PKEY_EC: {
      ec_private = EVP_PKEY_get1_EC_KEY(private_key);
      if (ec_private == NULL) {
        goto cleanup;
      }

      const BIGNUM *s = EC_KEY_get0_private_key(ec_private);
      const EC_GROUP *group = EC_KEY_get0_group(ec_private);
      int curve = EC_GROUP_get_curve_name(group);
      int size = 0;

      /* size is the per-component length in bytes for the curve. */
      if (curve == NID_X9_62_prime256v1) {
        *algo = YH_ALGO_EC_P256;
        size = 32;
      } else if (curve == NID_secp384r1) {
        *algo = YH_ALGO_EC_P384;
        size = 48;
      } else if (curve == NID_secp521r1) {
        *algo = YH_ALGO_EC_P521;
        size = 66;
      } else if (curve == NID_secp224r1) {
        *algo = YH_ALGO_EC_P224;
        size = 28;
#ifdef NID_brainpoolP256r1
      } else if (curve == NID_brainpoolP256r1) {
        *algo = YH_ALGO_EC_BP256;
        size = 32;
#endif
#ifdef NID_brainpoolP384r1
      } else if (curve == NID_brainpoolP384r1) {
        *algo = YH_ALGO_EC_BP384;
        size = 48;
#endif
#ifdef NID_brainpoolP512r1
      } else if (curve == NID_brainpoolP512r1) {
        *algo = YH_ALGO_EC_BP512;
        size = 64;
#endif
      } else if (curve == NID_secp256k1) {
        *algo = YH_ALGO_EC_K256;
        size = 32;
      } else {
        goto cleanup;
      }

      if (set_component(bytes, s, size) == false) {
        goto cleanup;
      }

      if (internal_repr == true) {
        const EC_POINT *ec_public = EC_KEY_get0_public_key(ec_private);

        x = BN_new();
        if (x == NULL) {
          goto cleanup;
        }

        y = BN_new();
        if (y == NULL) {
          goto cleanup;
        }

        if (EC_POINT_get_affine_coordinates_GFp(group, ec_public, x, y, NULL) ==
            0) {
          goto cleanup;
        }

        uint8_t *ptr = bytes + size;
        if (set_component(ptr, x, size) == false) {
          goto cleanup;
        }
        ptr += size;

        if (set_component(ptr, y, size) == false) {
          goto cleanup;
        }

        *bytes_len = size * 3;
      } else {
        *bytes_len = size;
      }
    } break;

    default:
      goto cleanup;
  }

  ret = true;

cleanup:

  if (rsa != NULL) {
    RSA_free(rsa);
    rsa = NULL;
  }

  if (x != NULL) {
    BN_free(x);
    x = NULL;
  }

  if (y != NULL) {
    BN_free(y);
    y = NULL;
  }

  if (ec_private != NULL) {
    EC_KEY_free(ec_private);
    ec_private = NULL;
  }

  EVP_PKEY_free(private_key);

  return ret;
}

/*
 * Parse a PEM EC public key and export it in octet-string form.
 *
 * buf / len   PEM text, parsed with PEM_read_bio_PUBKEY().
 * algo        receives the algorithm matching the key's curve, or 0 when the
 *             curve is not one of those listed below (the key is still
 *             exported in that case).
 * bytes       output buffer for the i2o_ECPublicKey() encoding.
 * bytes_len   in: capacity of bytes; out: number of bytes written.
 *
 * Returns false when the PEM cannot be parsed, the key is not an EC key, the
 * encoding does not fit in *bytes_len, or the encoding step fails.
 */
bool read_public_key(uint8_t *buf, size_t len, yh_algorithm *algo,
                      uint8_t *bytes, size_t *bytes_len) {
  BIO *bio = BIO_new(BIO_s_mem());
  if (bio == NULL) {
    return false;
  }

  if (BIO_write(bio, buf, len) <= 0) {
    BIO_free_all(bio);
    return false;
  }

  EVP_PKEY *pubkey = PEM_read_bio_PUBKEY(bio, NULL, NULL, NULL);
  BIO_free_all(bio);
  if (pubkey == NULL) {
    return false;
  }

  if (EVP_PKEY_base_id(pubkey) != EVP_PKEY_EC) {
    EVP_PKEY_free(pubkey);
    return false;
  }

  /* get1 takes its own reference, so the EVP_PKEY can be released now. */
  EC_KEY *ec = EVP_PKEY_get1_EC_KEY(pubkey);
  EVP_PKEY_free(pubkey);
  if (ec == NULL) {
    return false;
  }

  const EC_GROUP *group = EC_KEY_get0_group(ec);
  int curve = EC_GROUP_get_curve_name(group);

  if (curve == NID_X9_62_prime256v1) {
    *algo = YH_ALGO_EC_P256;
  } else if (curve == NID_secp384r1) {
    *algo = YH_ALGO_EC_P384;
  } else if (curve == NID_secp521r1) {
    *algo = YH_ALGO_EC_P521;
  } else if (curve == NID_secp224r1) {
    *algo = YH_ALGO_EC_P224;
#ifdef NID_brainpoolP256r1
  } else if (curve == NID_brainpoolP256r1) {
    *algo = YH_ALGO_EC_BP256;
#endif
#ifdef NID_brainpoolP384r1
  } else if (curve == NID_brainpoolP384r1) {
    *algo = YH_ALGO_EC_BP384;
#endif
#ifdef NID_brainpoolP512r1
  } else if (curve == NID_brainpoolP512r1) {
    *algo = YH_ALGO_EC_BP512;
#endif
  } else if (curve == NID_secp256k1) {
    *algo = YH_ALGO_EC_K256;
  } else {
    *algo = 0;
  }

  /* First call with no output buffer only reports the encoded length. */
  int data_len = i2o_ECPublicKey(ec, NULL);
  if (data_len <= 0 || (size_t) data_len > *bytes_len) {
    EC_KEY_free(ec);
    return false;
  }

  /*
   * Second call does the encoding. It advances its pointer argument, which
   * here is only the local copy of bytes. A result different from the
   * announced length means the encode failed and must not be reported as
   * success.
   */
  int encoded_len = i2o_ECPublicKey(ec, &bytes);
  EC_KEY_free(ec);
  if (encoded_len != data_len) {
    return false;
  }

  *bytes_len = (size_t) data_len;
  return true;
}

/*
 * Render len bytes of digest as lowercase hexadecimal text.
 *
 * str must have room for 2 * len characters plus the terminating NUL, which
 * is always written (for len == 0 the result is the empty string).
 */
void format_digest(uint8_t *digest, char *str, uint16_t len) {
  static const char hex_digits[] = "0123456789abcdef";
  const uint8_t *end = digest + len;
  char *cursor = str;

  for (const uint8_t *byte = digest; byte != end; byte++) {
    *cursor++ = hex_digits[*byte >> 4];   /* high nibble first */
    *cursor++ = hex_digits[*byte & 0x0f]; /* then low nibble */
  }

  *cursor = '\0';
}

/*
 * Map an elliptic-curve algorithm identifier to the matching OpenSSL NID.
 *
 * Returns 0 for any algorithm without a mapping, including the Brainpool and
 * Ed25519 algorithms when the corresponding NID_* macro is not defined at
 * build time.
 */
int algo2nid(yh_algorithm algo) {
  int nid = 0;

  switch (algo) {
    case YH_ALGO_EC_P256:
    case YH_ALGO_EC_P256_YUBICO_AUTHENTICATION:
      nid = NID_X9_62_prime256v1;
      break;

    case YH_ALGO_EC_P384:
      nid = NID_secp384r1;
      break;

    case YH_ALGO_EC_P521:
      nid = NID_secp521r1;
      break;

    case YH_ALGO_EC_P224:
      nid = NID_secp224r1;
      break;

    case YH_ALGO_EC_K256:
      nid = NID_secp256k1;
      break;

#ifdef NID_brainpoolP256r1
    case YH_ALGO_EC_BP256:
      nid = NID_brainpoolP256r1;
      break;
#endif

#ifdef NID_brainpoolP384r1
    case YH_ALGO_EC_BP384:
      nid = NID_brainpoolP384r1;
      break;
#endif

#ifdef NID_brainpoolP512r1
    case YH_ALGO_EC_BP512:
      nid = NID_brainpoolP512r1;
      break;
#endif

#ifdef NID_ED25519
    case YH_ALGO_EC_ED25519:
      nid = NID_ED25519;
      break;
#endif

    default:
      break;
  }

  return nid;
}

/*
 * Determine which object type an algorithm belongs to.
 *
 * On success stores the object type in *type and returns true. For the MGF1
 * algorithms and any unlisted value it returns false and leaves *type
 * untouched.
 */
bool algo2type(yh_algorithm algorithm, yh_object_type *type) {

  yh_object_type resolved;

  switch (algorithm) {
    /* Asymmetric keys and the signature/decryption schemes that use them. */
    case YH_ALGO_RSA_PKCS1_SHA1:
    case YH_ALGO_RSA_PKCS1_SHA256:
    case YH_ALGO_RSA_PKCS1_SHA384:
    case YH_ALGO_RSA_PKCS1_SHA512:
    case YH_ALGO_RSA_PSS_SHA1:
    case YH_ALGO_RSA_PSS_SHA256:
    case YH_ALGO_RSA_PSS_SHA384:
    case YH_ALGO_RSA_PSS_SHA512:
    case YH_ALGO_RSA_2048:
    case YH_ALGO_RSA_3072:
    case YH_ALGO_RSA_4096:
    case YH_ALGO_EC_P224:
    case YH_ALGO_EC_P256:
    case YH_ALGO_EC_P384:
    case YH_ALGO_EC_P521:
    case YH_ALGO_EC_K256:
    case YH_ALGO_EC_BP256:
    case YH_ALGO_EC_BP384:
    case YH_ALGO_EC_BP512:
    case YH_ALGO_EC_ECDSA_SHA1:
    case YH_ALGO_EC_ECDH:
    case YH_ALGO_RSA_OAEP_SHA1:
    case YH_ALGO_RSA_OAEP_SHA256:
    case YH_ALGO_RSA_OAEP_SHA384:
    case YH_ALGO_RSA_OAEP_SHA512:
    case YH_ALGO_EC_ECDSA_SHA256:
    case YH_ALGO_EC_ECDSA_SHA384:
    case YH_ALGO_EC_ECDSA_SHA512:
    case YH_ALGO_EC_ED25519:
      resolved = YH_ASYMMETRIC_KEY;
      break;

    case YH_ALGO_HMAC_SHA1:
    case YH_ALGO_HMAC_SHA256:
    case YH_ALGO_HMAC_SHA384:
    case YH_ALGO_HMAC_SHA512:
      resolved = YH_HMAC_KEY;
      break;

    case YH_ALGO_AES128_CCM_WRAP:
    case YH_ALGO_AES192_CCM_WRAP:
    case YH_ALGO_AES256_CCM_WRAP:
      resolved = YH_WRAP_KEY;
      break;

    case YH_ALGO_OPAQUE_DATA:
    case YH_ALGO_OPAQUE_X509_CERTIFICATE:
      resolved = YH_OPAQUE;
      break;

    case YH_ALGO_TEMPLATE_SSH:
      resolved = YH_TEMPLATE;
      break;

    case YH_ALGO_AES128_YUBICO_OTP:
    case YH_ALGO_AES192_YUBICO_OTP:
    case YH_ALGO_AES256_YUBICO_OTP:
      resolved = YH_OTP_AEAD_KEY;
      break;

    case YH_ALGO_AES128_YUBICO_AUTHENTICATION:
    case YH_ALGO_EC_P256_YUBICO_AUTHENTICATION:
      resolved = YH_AUTHENTICATION_KEY;
      break;

    case YH_ALGO_AES128:
    case YH_ALGO_AES192:
    case YH_ALGO_AES256:
    case YH_ALGO_AES_ECB:
    case YH_ALGO_AES_CBC:
      resolved = YH_SYMMETRIC_KEY;
      break;

    /* The MGF1 algorithms are listed explicitly but have no object type. */
    case YH_ALGO_MGF1_SHA1:
    case YH_ALGO_MGF1_SHA256:
    case YH_ALGO_MGF1_SHA384:
    case YH_ALGO_MGF1_SHA512:
    default:
      return false;
  }

  *type = resolved;
  return true;
}

/*
 * Identify the digest announced by a DigestInfo prefix at the start of data.
 *
 * Compares data against the SHA-1, SHA-256, SHA-384 and SHA-512 prefixes, in
 * that order. On a match *md_type is set to the corresponding EVP_MD and the
 * prefix length is returned, i.e. the offset at which the hash value starts.
 * Otherwise *md_type is set to EVP_md_null() and 0 is returned.
 */
int parse_NID(uint8_t *data, uint16_t data_len, const EVP_MD **md_type) {
  if (data_len >= sizeof(sha1oid) &&
      memcmp(sha1oid, data, sizeof(sha1oid)) == 0) {
    *md_type = EVP_sha1();
    return sizeof(sha1oid);
  }

  if (data_len >= sizeof(sha256oid) &&
      memcmp(sha256oid, data, sizeof(sha256oid)) == 0) {
    *md_type = EVP_sha256();
    return sizeof(sha256oid);
  }

  if (data_len >= sizeof(sha384oid) &&
      memcmp(sha384oid, data, sizeof(sha384oid)) == 0) {
    *md_type = EVP_sha384();
    return sizeof(sha384oid);
  }

  if (data_len >= sizeof(sha512oid) &&
      memcmp(sha512oid, data, sizeof(sha512oid)) == 0) {
    *md_type = EVP_sha512();
    return sizeof(sha512oid);
  }

  /* No known prefix. */
  *md_type = EVP_md_null();
  return 0;
}

/*
 * Read the remainder of fp into buf.
 *
 * buf_len  in: capacity of buf; out (success only): number of bytes read.
 *
 * Returns false on a stream error, or when the stream holds more data than
 * fits in buf. To tell "exactly full" from "too big", one extra byte is read
 * once the buffer is full and end-of-file has not been seen yet.
 */
bool read_file(FILE *fp, uint8_t *buf, size_t *buf_len) {
  size_t remaining = *buf_len;
  uint8_t *cursor = buf;

  for (;;) {
    size_t got = fread(cursor, 1, remaining, fp);
    cursor += got;
    remaining -= got;

    if (feof(fp) || ferror(fp) || remaining == 0) {
      break;
    }
  }

  if (ferror(fp)) {
    return false;
  }

  if (remaining == 0 && !feof(fp)) {
    /* Buffer is full: probe for one more byte to detect oversized input. */
    uint8_t probe[1];
    (void) fread(probe, 1, 1, fp);
    if (!feof(fp)) {
      return false;
    }
  }

  *buf_len = cursor - buf;
  return true;
}

/*
 * Decode the NUL-terminated, single-line base64 string in into out.
 *
 * len  in: capacity of out; out (success only): number of decoded bytes.
 *
 * Returns false when the BIOs cannot be created, when writing or flushing the
 * input fails (this includes an empty input string), or when the decoder
 * yields no bytes.
 */
bool base64_decode(const char *in, uint8_t *out, size_t *len) {
  BIO *decoder = BIO_new(BIO_f_base64());
  if (decoder == NULL) {
    return false;
  }

  BIO *source = BIO_new(BIO_s_mem());
  if (source == NULL) {
    BIO_free_all(decoder);
    return false;
  }

  /* Input is expected on one line, without embedded newlines. */
  BIO_set_flags(decoder, BIO_FLAGS_BASE64_NO_NL);
  BIO_push(decoder, source);

  int decoded = 0;
  bool loaded = false;

  if (BIO_write(source, in, strlen(in)) > 0) {
    if (BIO_flush(source) == 1) {
      loaded = true;
      decoded = BIO_read(decoder, out, *len);
    }
  }

  /* Frees the whole chain, i.e. both BIOs. */
  BIO_free_all(decoder);

  if (!loaded || decoded <= 0) {
    return false;
  }

  *len = decoded;
  return true;
}

/*
 * Write buf to fp, optionally re-encoded.
 *
 * format  _base64: single-line base64 of buf;
 *         _hex:    lowercase hexadecimal of buf;
 *         anything else (including _PEM and _binary): the bytes unchanged.
 *
 * When fp is stdout or stderr and format is not _binary, a newline is
 * appended after the data.
 *
 * Returns true only when all data was written and the stream was flushed
 * without error. A failing fflush() used to be ignored; it now yields false,
 * since buffered data that never reached the file is a failed write.
 */
bool write_file(const uint8_t *buf, size_t buf_len, FILE *fp, format_t format) {

  const uint8_t *p = buf;
  uint8_t *data = NULL;
  size_t length = buf_len;
  size_t written = 0;
  BIO *b64 = NULL;

  if (format == _base64) {
    BIO *bio;
    BUF_MEM *bufferPtr;

    b64 = BIO_new(BIO_f_base64());
    if (b64 == NULL) {
      return false;
    }
    bio = BIO_new(BIO_s_mem());
    if (bio == NULL) {
      BIO_free_all(b64);
      return false;
    }
    /* BIO_push returns the head of the chain, so bio now refers to b64. */
    bio = BIO_push(b64, bio);

    BIO_set_flags(bio, BIO_FLAGS_BASE64_NO_NL);
    if (BIO_write(bio, buf, buf_len) <= 0) {
      BIO_free_all(bio);
      return false;
    }
    if (BIO_flush(bio) != 1) {
      BIO_free_all(bio);
      return false;
    }
    BIO_get_mem_ptr(bio, &bufferPtr);

    /* The encoded text stays owned by the BIO chain until it is freed. */
    p = (uint8_t *) bufferPtr->data;
    length = bufferPtr->length;
  } else if (format == _hex) {
    static const char hex_digits[] = "0123456789abcdef";

    data = calloc(buf_len * 2 + 1, 1);
    if (data == NULL) {
      return false;
    }
    for (size_t i = 0; i < buf_len; i++) {
      data[i * 2] = (uint8_t) hex_digits[buf[i] >> 4];
      data[i * 2 + 1] = (uint8_t) hex_digits[buf[i] & 0x0f];
    }
    p = data;
    length = buf_len * 2;
  }
  /* Every other format writes buf as-is (p and length already set). */

  /* fwrite may write less than requested; keep going until done or failed. */
  do {
    written = fwrite(p, 1, length, fp);
    length -= written;
    p += written;
  } while (!feof(fp) && !ferror(fp) && length > 0);

  if (fp == stdout || fp == stderr) {
    if (format != _binary) {
      fprintf(fp, "\n");
    }
  }

  if (b64 != NULL) {
    BIO_free_all(b64);
    b64 = NULL;
  }

  free(data);
  data = NULL;

  if (ferror(fp) || feof(fp)) {
    return false;
  }

  return fflush(fp) == 0;
}

/*
 * Write an Ed25519 public key to fp.
 *
 * buf / buf_len  raw public key bytes.
 * format  _PEM:    PEM_public_header, base64 of (ed25519public_oid || buf),
 *                  PEM_public_trailer;
 *         _base64: only the base64 line;
 *         _hex:    hexadecimal of buf alone (no DER prefix);
 *         other:   not supported, returns false.
 *
 * write_file() appends a newline itself when fp is stdout or stderr. For
 * those streams the header and trailer are therefore written without their
 * own trailing newline and no explicit newline follows the payload; for any
 * other stream the newline is written explicitly.
 *
 * Returns false when buf does not fit the 64-byte DER scratch buffer, when
 * the format is unsupported, or when any write fails. Write failures used to
 * be ignored, so the function reported success even if nothing was written.
 */
bool write_ed25519_key(uint8_t *buf, size_t buf_len, FILE *fp,
                       format_t format) {

  const bool to_console = (fp == stdout || fp == stderr);
  const uint8_t newline = '\n';

  if (format == _base64 || format == _PEM) {
    uint8_t asn1[64];
    const size_t drop_newline = to_console ? 1 : 0;

    /* Written as a subtraction so the check itself cannot overflow. */
    if (buf_len > sizeof(asn1) - sizeof(ed25519public_oid)) {
      return false;
    }
    memcpy(asn1, ed25519public_oid, sizeof(ed25519public_oid));
    memcpy(asn1 + sizeof(ed25519public_oid), buf, buf_len);

    /* sizeof - 1 skips the string literal's terminating NUL. */
    if (format == _PEM &&
        !write_file(PEM_public_header,
                    sizeof(PEM_public_header) - 1 - drop_newline, fp, _PEM)) {
      return false;
    }

    if (!write_file(asn1, sizeof(ed25519public_oid) + buf_len, fp, _base64)) {
      return false;
    }
    if (!to_console && !write_file(&newline, 1, fp, _PEM)) {
      return false;
    }

    if (format == _PEM &&
        !write_file(PEM_public_trailer,
                    sizeof(PEM_public_trailer) - 1 - drop_newline, fp, _PEM)) {
      return false;
    }

    return true;
  }

  if (format == _hex) {
    if (!write_file(buf, buf_len, fp, _hex)) {
      return false;
    }
    if (!to_console && !write_file(&newline, 1, fp, _PEM)) {
      return false;
    }

    return true;
  }

  return false; // TODO(adma): _binary?
}

/*
 * Derive the two padded HMAC key blocks from a raw HMAC key.
 *
 * The key is zero-padded to the digest's block size; out then receives
 * (key XOR 0x36) followed by (key XOR 0x5c), each block_size bytes long, and
 * *out_len is set to 2 * block_size.
 *
 * Returns false for a non-HMAC algorithm or when in_len exceeds the block
 * size (hashing over-long keys is not implemented).
 *
 * NOTE: the incoming *out_len is not consulted; out must hold at least
 * 2 * block_size bytes.
 */
bool split_hmac_key(yh_algorithm algorithm, uint8_t *in, size_t in_len,
                    uint8_t *out, size_t *out_len) {

  uint8_t key[128 * 2] = {0};
  uint8_t block_size;

  switch (algorithm) {
    case YH_ALGO_HMAC_SHA1:
      block_size = EVP_MD_block_size(EVP_sha1());
      break;

    case YH_ALGO_HMAC_SHA256:
      block_size = EVP_MD_block_size(EVP_sha256());
      break;

    case YH_ALGO_HMAC_SHA384:
      block_size = EVP_MD_block_size(EVP_sha384());
      break;

    case YH_ALGO_HMAC_SHA512:
      block_size = EVP_MD_block_size(EVP_sha512());
      break;

    default:
      return false;
  }

  if (in_len > block_size) {
    return false; // TODO(adma): hash the key
  }

  /* key is pre-zeroed, so bytes past in_len act as the zero padding. */
  memcpy(key, in, in_len);

  for (uint8_t i = 0; i < block_size; i++) {
    out[i] = key[i] ^ 0x36;
    out[i + block_size] = key[i] ^ 0x5c;
  }

  *out_len = 2 * block_size;

  /*
   * Do not leave a copy of the secret key on the stack. The stores go through
   * a volatile pointer so the compiler cannot drop them as dead writes.
   */
  volatile uint8_t *wipe = key;
  for (size_t i = 0; i < sizeof(key); i++) {
    wipe[i] = 0;
  }

  return true;
}

/*
 * Build an OpenSSL EVP_PKEY from raw public key bytes.
 *
 * pubkey / pubkey_len  RSA: the big-endian modulus (the exponent is fixed to
 *                      0x010001); EC: the point coordinates without the
 *                      leading 0x04 format byte, which is prepended here;
 *                      Ed25519: the raw public key.
 * pubkey_algo          selects which of the three forms is expected.
 * key                  receives the new EVP_PKEY on success; the caller owns
 *                      it. On failure *key is left untouched.
 *
 * Returns false for unsupported algorithms, for an EC key too long for the
 * internal buffer, and when any OpenSSL call fails.
 *
 * The result is built in a local and published only on success, so a failure
 * after the EVP_PKEY was allocated no longer leaks it.
 */
bool get_pubkey_evp(uint8_t *pubkey, size_t pubkey_len,
                    yh_algorithm pubkey_algo, EVP_PKEY **key) {

  RSA *rsa = NULL;
  BIGNUM *e = NULL;
  BIGNUM *n = NULL;
  EC_KEY *ec_key = NULL;
  EC_GROUP *ec_group = NULL;
  EC_POINT *ec_point = NULL;
  EVP_PKEY *pkey = NULL;

  if (yh_is_rsa(pubkey_algo)) {
    rsa = RSA_new();
    e = BN_new();
    if (rsa == NULL || e == NULL) {
      goto l_p_k_failure;
    }

    if (BN_set_word(e, 0x010001) != 1) {
      goto l_p_k_failure;
    }

    n = BN_bin2bn(pubkey, pubkey_len, NULL);
    if (n == NULL) {
      goto l_p_k_failure;
    }

    if (RSA_set0_key(rsa, n, e, NULL) == 0) {
      goto l_p_k_failure;
    }

    /* rsa now owns n and e; clear them so the failure path skips them. */
    n = NULL;
    e = NULL;

    pkey = EVP_PKEY_new();
    if (pkey == NULL) {
      goto l_p_k_failure;
    }

    /* On success pkey takes ownership of rsa. */
    if (EVP_PKEY_assign_RSA(pkey, rsa) == 0) {
      goto l_p_k_failure;
    }
  } else if (yh_is_ec(pubkey_algo)) {
    ec_key = EC_KEY_new();
    if (ec_key == NULL) {
      goto l_p_k_failure;
    }

    ec_group = EC_GROUP_new_by_curve_name(algo2nid(pubkey_algo));
    if (ec_group == NULL) {
      goto l_p_k_failure;
    }

    // NOTE: this call is important since it makes it a named curve instead of
    // encoded parameters
    EC_GROUP_set_asn1_flag(ec_group, OPENSSL_EC_NAMED_CURVE);

    if (EC_KEY_set_group(ec_key, ec_group) == 0) {
      goto l_p_k_failure;
    }

    ec_point = EC_POINT_new(ec_group);
    if (ec_point == NULL) {
      goto l_p_k_failure;
    }

    uint8_t ec_pubkey[YH_MSG_BUF_SIZE] = {0};

    /* One byte is reserved for the format prefix; reject what cannot fit. */
    if (pubkey_len > sizeof(ec_pubkey) - 1) {
      goto l_p_k_failure;
    }

    ec_pubkey[0] = 0x04; // hack to make it a valid ec pubkey.
    memcpy(ec_pubkey + 1, pubkey, pubkey_len);
    if (EC_POINT_oct2point(ec_group, ec_point, ec_pubkey, pubkey_len + 1,
                           NULL) == 0) {
      goto l_p_k_failure;
    }

    if (EC_KEY_set_public_key(ec_key, ec_point) == 0) {
      goto l_p_k_failure;
    }

    pkey = EVP_PKEY_new();
    if (pkey == NULL) {
      goto l_p_k_failure;
    }

    /* On success pkey takes ownership of ec_key. */
    if (EVP_PKEY_assign_EC_KEY(pkey, ec_key) == 0) {
      goto l_p_k_failure;
    }

    /* The point and group were only needed to populate ec_key. */
    EC_POINT_free(ec_point);
    EC_GROUP_free(ec_group);
#if (OPENSSL_VERSION_NUMBER >= 0x10100000L)
  } else if (yh_is_ed(pubkey_algo)) {
    pkey = EVP_PKEY_new_raw_public_key(algo2nid(pubkey_algo), NULL, pubkey,
                                       pubkey_len);
    if (pkey == NULL) {
      goto l_p_k_failure;
    }
#endif
  } else {
    goto l_p_k_failure;
  }

  *key = pkey;
  return true;

l_p_k_failure:
  /*
   * Reaching here with a non-NULL pkey means the assign step failed, so pkey
   * does not own rsa/ec_key and both must be released separately.
   */
  EVP_PKEY_free(pkey);
  EC_POINT_free(ec_point);
  EC_GROUP_free(ec_group);
  EC_KEY_free(ec_key);
  RSA_free(rsa);
  BN_free(n);
  BN_free(e);

  return false;
}
