/*
 * Copyright (C) by
 *   MetraLabs GmbH (MLAB), GERMANY
 * and
 *   Neuroinformatics and Cognitive Robotics Labs (NICR) at TU Ilmenau, GERMANY
 * All rights reserved.
 *
 * Contact: info@mira-project.org
 *
 * Commercial Usage:
 *   Licensees holding valid commercial licenses may use this file in
 *   accordance with the commercial license agreement provided with the
 *   software or, alternatively, in accordance with the terms contained in
 *   a written agreement between you and MLAB or NICR.
 *
 * GNU General Public License Usage:
 *   Alternatively, this file may be used under the terms of the GNU
 *   General Public License version 3.0 as published by the Free Software
 *   Foundation and appearing in the file LICENSE.GPL3 included in the
 *   packaging of this file. Please review the following information to
 *   ensure the GNU General Public License version 3.0 requirements will be
 *   met: http://www.gnu.org/copyleft/gpl.html.
 *   Alternatively you may (at your option) use any later version of the GNU
 *   General Public License if such license has been publicly approved by
 *   MLAB and NICR (or its successors, if any).
 *
 * IN NO EVENT SHALL "MLAB" OR "NICR" BE LIABLE TO ANY PARTY FOR DIRECT,
 * INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
 * THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF "MLAB" OR
 * "NICR" HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * "MLAB" AND "NICR" SPECIFICALLY DISCLAIM ANY WARRANTIES, INCLUDING,
 * BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS
 * ON AN "AS IS" BASIS, AND "MLAB" AND "NICR" HAVE NO OBLIGATION TO
 * PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS OR MODIFICATIONS.
 */

/**
 * @file RSASignature.C
 *    Implementation of RSASignature.h: creation and verification of
 *    message signatures based on RSA keys.
 *
 * @author Christian Martin
 * @date   2023/12/xx
 */

#include <security/RSASignature.h>

#include <cstring>
#include <memory>

#include <boost/format.hpp>

#include <utils/ToString.h>

#include "../OpenSSLHelper.h"

namespace mira {

///////////////////////////////////////////////////////////////////////////////

/**
 * Creates an empty signature (no data, size 0).
 */
RSASignature::RSASignature() :
	mSize(0),
	mData(NULL)
{
	// Some systems (e.g. RedHat 10) disable (insecure) SHA1 digest signatures
	// by default. No OpenSSL API call was found to re-enable them, so the
	// environment variable is set instead, overwriting any existing value.
	// (Christian, 2025-05-26)
	static const char* const sha1EnvVar = "OPENSSL_ENABLE_SHA1_SIGNATURES";
	const int overwriteExisting = 1;
	setenv(sha1EnvVar, "1", overwriteExisting);

	// Access the OpenSSLCleanup singleton (see OpenSSLHelper.h); presumably
	// this ensures OpenSSL is initialized/cleaned up properly - TODO confirm.
	OpenSSLCleanup::instance();
}

/**
 * Copy constructor: creates a deep copy of the signature data.
 * An empty source signature results in an empty copy (no buffer allocated).
 */
RSASignature::RSASignature(const RSASignature& signature) :
	mSize(signature.mSize),
	mData(NULL)
{
	OpenSSLCleanup::instance();

	// memcpy with a NULL source is undefined behavior even for 0 bytes, and a
	// default-constructed signature has mData == NULL. So only copy when
	// there actually is data.
	if (mSize > 0) {
		mData = new uint8[mSize];
		memcpy(mData, signature.mData, mSize);
	}
}

/**
 * Releases the signature buffer and resets the object to the empty state.
 */
RSASignature::~RSASignature()
{
	uint8* buffer = mData;
	mData = NULL;
	mSize = 0;
	delete [] buffer;
}

///////////////////////////////////////////////////////////////////////////////

/**
 * Copy assignment: replaces the signature data by a deep copy of the data
 * of \p signature. Safe against self-assignment. If the allocation throws,
 * *this is left unchanged.
 */
RSASignature& RSASignature::operator=(const RSASignature& signature)
{
	// Self-assignment would otherwise free our buffer and then copy
	// uninitialized memory onto itself.
	if (this == &signature)
		return *this;

	// Allocate and fill the new buffer first, so that a throwing allocation
	// does not leave mData pointing to already freed memory.
	uint8* newData = NULL;
	if (signature.mSize > 0) {
		newData = new uint8[signature.mSize];
		memcpy(newData, signature.mData, signature.mSize);
	}

	delete [] mData;
	mData = newData;
	mSize = signature.mSize;
	return *this;
}

///////////////////////////////////////////////////////////////////////////////

/**
 * Signs a message with a private RSA key.
 *
 * The digest of the message (selected by iDigestType) is computed and then
 * signed with iPrivateKey using RSA PKCS#1 padding (RSA_PKCS1_PADDING).
 *
 * @param iPrivateKey The private RSA key used for signing.
 * @param iDigestType The message digest algorithm (MD5, SHA1, SHA256, SHA512).
 * @param iMsg        Pointer to the message data (must not be NULL).
 * @param iMsgLen     Length of the message in bytes (must not be 0).
 * @return The signature; its size equals the size of the key.
 * @throw XInvalidParameter if the key is not a private key, the message is
 *        empty, the key size is invalid or the signature size does not match
 *        the key size.
 * @throw XSystemCall if the digest type is unknown or an OpenSSL call fails.
 * @throw XLogical if the digest has an unexpected length.
 */
RSASignature RSASignature::signMessage(const RSAKey& iPrivateKey,
                                       DigestType iDigestType,
                                       const char* iMsg,
                                       size_t iMsgLen)
{
	if (!iPrivateKey.isPrivateKey())
		MIRA_THROW(XInvalidParameter,
		          "Need a private RSA key to sign a message.");

	if ((iMsg == NULL) || (iMsgLen == 0))
		MIRA_THROW(XInvalidParameter, "Can't sign an empty message.");

	////////////////////////////////////////////////////////////////
	// build the digest of the message

	size_t digestLen = 0;

	const EVP_MD* md = NULL;
	if (iDigestType == DIGEST_MD5) {
		digestLen = MD5_DIGEST_LENGTH;
		md = EVP_get_digestbyname("MD5");
	} else
	if (iDigestType == DIGEST_SHA1) {
		digestLen = SHA_DIGEST_LENGTH;
		md = EVP_get_digestbyname("SHA1");
	} else
	if (iDigestType == DIGEST_SHA256) {
		digestLen = SHA256_DIGEST_LENGTH;
		md = EVP_get_digestbyname("SHA256");
	} else
	if (iDigestType == DIGEST_SHA512) {
		digestLen = SHA512_DIGEST_LENGTH;
		md = EVP_get_digestbyname("SHA512");
	}

	if (md == NULL)
		MIRA_THROW(XSystemCall, "Unknown message digest algorithm.");

	// The digest context is released automatically on every exit path
	// (including exceptions). The OpenSSL error is read while building the
	// exception message, i.e. before the context is freed during unwinding.
	std::unique_ptr<EVP_MD_CTX, decltype(&EVP_MD_CTX_free)>
		mdCtx(EVP_MD_CTX_new(), &EVP_MD_CTX_free);
	if (!mdCtx) {
		MIRA_THROW(XSystemCall, "EVP_MD_CTX_new failed:" <<
		           OpenSSLErrorString::instance().err2str(ERR_get_error()));
	}
	if (!EVP_DigestInit_ex2(mdCtx.get(), md, NULL)) {
		MIRA_THROW(XSystemCall, "EVP_DigestInit_ex2 failed: " <<
		           OpenSSLErrorString::instance().err2str(ERR_get_error()));
	}
	if (!EVP_DigestUpdate(mdCtx.get(), iMsg, iMsgLen)) {
		MIRA_THROW(XSystemCall, "EVP_DigestUpdate failed: " <<
		           OpenSSLErrorString::instance().err2str(ERR_get_error()));
	}

	unsigned int dlen = 0;
	unsigned char digest[EVP_MAX_MD_SIZE];
	if (!EVP_DigestFinal_ex(mdCtx.get(), digest, &dlen)) {
		MIRA_THROW(XSystemCall, "EVP_DigestFinal_ex failed: " <<
		           OpenSSLErrorString::instance().err2str(ERR_get_error()));
	}
	if (dlen != digestLen) {
		MIRA_THROW(XLogical, "Unexpected digest size. " <<
		           "Expected length = " << digestLen <<
		           ", Digest length = " << dlen);
	}

	mdCtx.reset();

	////////////////////////////////////////////////////////////////
	// initialize sign context

	// Owned by a unique_ptr: the original code never freed this context on
	// the success path (and not on the signature-size mismatch path).
	std::unique_ptr<EVP_PKEY_CTX, decltype(&EVP_PKEY_CTX_free)>
		ctx(EVP_PKEY_CTX_new(iPrivateKey.getOpenSSLKey()->key, NULL),
		    &EVP_PKEY_CTX_free);
	if (!ctx) {
		MIRA_THROW(XSystemCall, "EVP_PKEY_CTX_new failed:" <<
		           OpenSSLErrorString::instance().err2str(ERR_get_error()));
	}
	if (EVP_PKEY_sign_init(ctx.get()) <= 0) {
		MIRA_THROW(XSystemCall, "EVP_PKEY_sign_init failed: " <<
		           OpenSSLErrorString::instance().err2str(ERR_get_error()));
	}
	if (EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_PADDING) <= 0) {
		MIRA_THROW(XSystemCall, "EVP_PKEY_CTX_set_rsa_padding failed: " <<
		           OpenSSLErrorString::instance().err2str(ERR_get_error()));
	}
	if (EVP_PKEY_CTX_set_signature_md(ctx.get(), md) <= 0) {
		MIRA_THROW(XSystemCall, "EVP_PKEY_CTX_set_signature_md failed: " <<
		           OpenSSLErrorString::instance().err2str(ERR_get_error()));
	}

	///////////////////////////////////////////////////////////////////////////
	// now sign the message

	int keySize = EVP_PKEY_get_size(iPrivateKey.getOpenSSLKey()->key);
	if (keySize < 1)
		MIRA_THROW(XInvalidParameter, "Invalid private key.");

	size_t sigLen = static_cast<size_t>(keySize);
	std::unique_ptr<uint8[]> signature(new uint8[keySize]);
	if (EVP_PKEY_sign(ctx.get(), signature.get(), &sigLen, digest, digestLen) <= 0) {
		MIRA_THROW(XSystemCall, "EVP_PKEY_sign failed: " <<
		           OpenSSLErrorString::instance().err2str(ERR_get_error()));
	}

	if (static_cast<size_t>(keySize) != sigLen) {
		MIRA_THROW(XInvalidParameter, "Unexpected RSA signature length: "
		           "Size of the signature must be equal to the size of the key "
		           << "(KeySize=" << keySize
		           << ", SignatureSize=" << sigLen << ").");
	}

	///////////////////////////////////////////////////////////////////////////

	RSASignature res;
	res.mSize = sigLen;
	res.mData = signature.release(); // ownership moves to the signature object

	return res;
}

bool RSASignature::verifyMessage(const RSAKey& iPublicKey,
                                 DigestType iDigestType,
                                 const char* iMsg,
                                 size_t iMsgLen,
                                 const RSASignature& iSignature)
{
	if (!iPublicKey.isPublicKey())
		MIRA_THROW(XInvalidParameter,
		           "Need a public RSA key to verify a message.");

	if ((iMsg == NULL) || (iMsgLen == 0))
		MIRA_THROW(XInvalidParameter, "Can't verify an empty message.");

	////////////////////////////////////////////////////////////////
	// build the digest of the message

	size_t digestLen = 0;

	const EVP_MD* md = NULL;
	if (iDigestType == DIGEST_MD5) {
		digestLen = MD5_DIGEST_LENGTH;
		md = EVP_get_digestbyname("MD5");
	} else
	if (iDigestType == DIGEST_SHA1) {
		digestLen = SHA_DIGEST_LENGTH;
		md = EVP_get_digestbyname("SHA1");
	} else
	if (iDigestType == DIGEST_SHA256) {
		digestLen = SHA256_DIGEST_LENGTH;
		md = EVP_get_digestbyname("SHA256");
	} else
	if (iDigestType == DIGEST_SHA512) {
		digestLen = SHA512_DIGEST_LENGTH;
		md = EVP_get_digestbyname("SHA512");
	}

	if (md == NULL)
		MIRA_THROW(XSystemCall, "Unknown message digest algorithm.");

	EVP_MD_CTX* mdCtx = EVP_MD_CTX_new();
	if (mdCtx == NULL) {
		MIRA_THROW(XSystemCall, "EVP_MD_CTX_new failed:" <<
		           OpenSSLErrorString::instance().err2str(ERR_get_error()));
	}
	if (!EVP_DigestInit_ex2(mdCtx, md, NULL)) {
		auto errNo = ERR_get_error();
		EVP_MD_CTX_free(mdCtx);
		MIRA_THROW(XSystemCall, "EVP_DigestInit_ex2 failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}
	if (!EVP_DigestUpdate(mdCtx, iMsg, iMsgLen)) {
		auto errNo = ERR_get_error();
		EVP_MD_CTX_free(mdCtx);
		MIRA_THROW(XSystemCall, "EVP_DigestInit_ex2 failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}

	unsigned int dlen = 0;
	unsigned char digest[EVP_MAX_MD_SIZE];
	if (!EVP_DigestFinal_ex(mdCtx, digest, &dlen)) {
		auto errNo = ERR_get_error();
		EVP_MD_CTX_free(mdCtx);
		MIRA_THROW(XSystemCall, "EVP_DigestFinal_ex failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}
	if (dlen != digestLen) {
		EVP_MD_CTX_free(mdCtx);
		MIRA_THROW(XLogical, "Unexpected digest size. " <<
		           "Expected length = " << digestLen <<
		           ", Digest length = " << dlen);
	}

	EVP_MD_CTX_free(mdCtx);

	////////////////////////////////////////////////////////////////
	// initialize the verify context

	EVP_PKEY_CTX* ctx = EVP_PKEY_CTX_new(iPublicKey.getOpenSSLKey()->key, NULL);
	if (!ctx) {
		MIRA_THROW(XSystemCall, "EVP_PKEY_CTX_new failed:" <<
		           OpenSSLErrorString::instance().err2str(ERR_get_error()));
	}
	if (EVP_PKEY_verify_init(ctx) <= 0) {
		auto errNo = ERR_get_error();
		EVP_PKEY_CTX_free(ctx);
		MIRA_THROW(XSystemCall, "EVP_PKEY_verify_init failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}
	if (EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_PADDING) <= 0) {
		auto errNo = ERR_get_error();
		EVP_PKEY_CTX_free(ctx);
		MIRA_THROW(XSystemCall, "EVP_PKEY_CTX_set_rsa_padding failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}
	if (EVP_PKEY_CTX_set_signature_md(ctx, md) <= 0) {
		auto errNo = ERR_get_error();
		EVP_PKEY_CTX_free(ctx);
		MIRA_THROW(XSystemCall, "EVP_PKEY_CTX_set_signature_md failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}

	////////////////////////////////////////////////////////////////
	// now verify the message

	int keySize = EVP_PKEY_get_size(iPublicKey.getOpenSSLKey()->key);
	if (keySize < 1) {
		EVP_PKEY_CTX_free(ctx);
		MIRA_THROW(XInvalidParameter, "Invalid public key.");
	}

	int res = EVP_PKEY_verify(ctx, iSignature.mData, iSignature.mSize, digest, digestLen);

	auto errNo = ERR_get_error();
	EVP_PKEY_CTX_free(ctx);

	return res==1;
}

///////////////////////////////////////////////////////////////////////////////

std::ostream& operator<<(std::ostream& stream, const RSASignature& signature)
{
	std::stringstream tStr;
	for(size_t i = 0; i < signature.mSize; i++)
		tStr << boost::format("%02X") % (int)signature.mData[i];

	stream << tStr.str();
	return stream;
}

std::istream& operator>>(std::istream& stream, RSASignature& signature)
{
	// read hex string from stream
	std::string s;
	stream >> s;

	if ((s.size() % 2) != 0)
		MIRA_THROW(XInvalidParameter, "Need a even number of chars in "
		          "hex-string to convert to a RSA signature");

	size_t len = s.size()/2;
	uint8* data = new uint8[len];
	const char* srcPtr = s.data();
	for(size_t i = 0; i < len; i++, srcPtr += 2) {
		int v = 0;
		sscanf(srcPtr, "%02x", &v);
		data[i] = v;
	}

	// store into signature and cleanup
	delete [] signature.mData;
	signature.mData = data;
	signature.mSize = len;

	return stream;
}

///////////////////////////////////////////////////////////////////////////////

} // namespaces
