/*
 * Copyright (C) 2012 by
 *   MetraLabs GmbH (MLAB), GERMANY
 * and
 *   Neuroinformatics and Cognitive Robotics Labs (NICR) at TU Ilmenau, GERMANY
 * All rights reserved.
 *
 * Contact: info@mira-project.org
 *
 * Commercial Usage:
 *   Licensees holding valid commercial licenses may use this file in
 *   accordance with the commercial license agreement provided with the
 *   software or, alternatively, in accordance with the terms contained in
 *   a written agreement between you and MLAB or NICR.
 *
 * GNU General Public License Usage:
 *   Alternatively, this file may be used under the terms of the GNU
 *   General Public License version 3.0 as published by the Free Software
 *   Foundation and appearing in the file LICENSE.GPL3 included in the
 *   packaging of this file. Please review the following information to
 *   ensure the GNU General Public License version 3.0 requirements will be
 *   met: http://www.gnu.org/copyleft/gpl.html.
 *   Alternatively you may (at your option) use any later version of the GNU
 *   General Public License if such license has been publicly approved by
 *   MLAB and NICR (or its successors, if any).
 *
 * IN NO EVENT SHALL "MLAB" OR "NICR" BE LIABLE TO ANY PARTY FOR DIRECT,
 * INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
 * THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF "MLAB" OR
 * "NICR" HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * "MLAB" AND "NICR" SPECIFICALLY DISCLAIM ANY WARRANTIES, INCLUDING,
 * BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS
 * ON AN "AS IS" BASIS, AND "MLAB" AND "NICR" HAVE NO OBLIGATION TO
 * PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS OR MODIFICATIONS.
 */

/**
 * @file RSAKey.C
 *    Implementation of RSAKey.h
 *
 * @author Christian Martin
 * @date   2010/09/10
 */

#include <security/RSAKey.h>

#include <boost/format.hpp>
#include <boost/random.hpp>

#include <utils/StringAlgorithms.h>
#include <utils/Time.h>

#include <error/Exceptions.h>

#include "OpenSSLHelper.h"

using namespace std;

namespace mira {

///////////////////////////////////////////////////////////////////////////////
// Implementation of RSAKey

/**
 * Creates an empty RSA key: the wrapper holds a freshly allocated OpenSSL
 * RSA structure whose components have not been set yet.
 */
RSAKey::RSAKey()
{
	// NOTE(review): presumably makes sure the OpenSSL cleanup singleton exists
	// before any OpenSSL object is created -- confirm in OpenSSLHelper.h.
	OpenSSLCleanup::instance();

	OpenSSLRSAWrapper* tWrapper = new OpenSSLRSAWrapper();
	tWrapper->key = RSA_new();
	mKey = tWrapper;
}

/**
 * Copy constructor: creates an independent deep copy of @p key.
 *
 * If the source has a private exponent, the private key (with all of its
 * components) is duplicated. Otherwise only the public part is duplicated.
 *
 * @throw XSystemCall if OpenSSL fails to duplicate the key.
 */
RSAKey::RSAKey(const RSAKey& key)
{
	OpenSSLCleanup::instance();

	mKey = new OpenSSLRSAWrapper();

	ERR_clear_error();

	if (key.mKey->key->d != NULL)
		mKey->key = RSAPrivateKey_dup(key.mKey->key);
	else
		mKey->key = RSAPublicKey_dup(key.mKey->key);

	if (mKey->key == NULL) {
		unsigned long tErrNo = ERR_get_error();
		// A constructor that throws never runs the destructor, so the
		// wrapper allocated above must be released here to avoid a leak.
		delete mKey;
		mKey = NULL;
		MIRA_THROW(XSystemCall, "Failed to duplicate RSA key: " <<
		           OpenSSLErrorString::instance().err2str(tErrNo));
	}
}

RSAKey::RSAKey(const string& n, const string& e, const string& d)
{
	OpenSSLCleanup::instance();

	mKey = new OpenSSLRSAWrapper();
	mKey->key = RSA_new();

	if (n.size() > 0) {
		mKey->key->n = BN_new();
		if (BN_hex2bn(&mKey->key->n, n.c_str()) == 0)
			BN_zero(mKey->key->n);
	}
	if (e.size() > 0) {
		mKey->key->e = BN_new();
		if (BN_hex2bn(&mKey->key->e, e.c_str()) == 0)
			BN_zero(mKey->key->e);
	}
	if (d.size() > 0) {
		mKey->key->d = BN_new();
		if (BN_hex2bn(&mKey->key->d, d.c_str()) == 0)
			BN_zero(mKey->key->d);
	}
}

/**
 * Destroys the key: frees the OpenSSL RSA structure first, then the wrapper
 * that held it.
 */
RSAKey::~RSAKey()
{
	OpenSSLRSAWrapper* tWrapper = mKey;
	mKey = NULL;

	RSA_free(tWrapper->key);
	delete tWrapper;
}

///////////////////////////////////////////////////////////////////////////////

/**
 * Creates a key as a deep copy of the OpenSSL key held by @p key.
 *
 * If the source has a private exponent, the private key is duplicated.
 * Otherwise only the public part is duplicated.
 *
 * @throw XInvalidParameter if @p key (or the RSA structure it holds) is NULL.
 * @throw XSystemCall if OpenSSL fails to duplicate the key.
 */
RSAKey::RSAKey(const OpenSSLRSAWrapper* key)
{
	// Validate before allocating anything, so that throwing cannot leak.
	if ((key == NULL) || (key->key == NULL))
		MIRA_THROW(XInvalidParameter, "Key must not be NULL.");

	// Same OpenSSL setup as in all other constructors of this class.
	OpenSSLCleanup::instance();

	mKey = new OpenSSLRSAWrapper();

	ERR_clear_error();

	if (key->key->d != NULL)
		mKey->key = RSAPrivateKey_dup(key->key);
	else
		mKey->key = RSAPublicKey_dup(key->key);

	if (mKey->key == NULL) {
		unsigned long tErrNo = ERR_get_error();
		// The destructor does not run for a throwing constructor, so release
		// the wrapper here.
		delete mKey;
		mKey = NULL;
		MIRA_THROW(XSystemCall, "Failed to duplicate RSA key: " <<
		           OpenSSLErrorString::instance().err2str(tErrNo));
	}
}

///////////////////////////////////////////////////////////////////////////////

bool RSAKey::isValid() const
{
	return((mKey->key->n != NULL) && (!BN_is_zero(mKey->key->n)) &&
	       (((mKey->key->e != NULL) && (!BN_is_zero(mKey->key->e))) ||
	       ((mKey->key->d != NULL) && (!BN_is_zero(mKey->key->d)))));
}

bool RSAKey::isPublicKey() const
{
	return((mKey->key->n != NULL) && (!BN_is_zero(mKey->key->n)) &&
	       (mKey->key->e != NULL) && (!BN_is_zero(mKey->key->e)));
}

bool RSAKey::isPrivateKey() const
{
	return((mKey->key->n != NULL) && (!BN_is_zero(mKey->key->n)) &&
	       (mKey->key->d != NULL) && (!BN_is_zero(mKey->key->d)));
}

bool RSAKey::clear()
{
	if (mKey->key->n != NULL)
		BN_zero(mKey->key->n);
	if (mKey->key->e != NULL)
		BN_zero(mKey->key->e);
	if (mKey->key->d != NULL)
		BN_zero(mKey->key->d);

	return(true);
}

///////////////////////////////////////////////////////////////////////////////

/**
 * Assigns a deep copy of @p key to this key.
 *
 * If the source has a private exponent, the private key is duplicated.
 * Otherwise only the public part is duplicated. The source is duplicated
 * first and our old key is released afterwards. This makes self-assignment
 * safe (freeing first would read freed memory), and if duplication fails
 * this key stays unchanged instead of holding a dangling pointer that the
 * destructor would free a second time.
 *
 * @throw XSystemCall if OpenSSL fails to duplicate the key.
 */
RSAKey& RSAKey::operator=(const RSAKey& key)
{
	ERR_clear_error();

	RSA* tNewKey = NULL;
	if (key.mKey->key->d != NULL)
		tNewKey = RSAPrivateKey_dup(key.mKey->key);
	else
		tNewKey = RSAPublicKey_dup(key.mKey->key);

	if (tNewKey == NULL) {
		unsigned long tErrNo = ERR_get_error();
		MIRA_THROW(XSystemCall, "Failed to duplicate RSA key: " <<
		           OpenSSLErrorString::instance().err2str(tErrNo));
	}

	RSA_free(mKey->key);
	mKey->key = tNewKey;

	return(*this);
}

/**
 * Two keys are equal if both are valid and their n, e and d components
 * compare equal (BN_cmp). Invalid keys never compare equal.
 */
bool RSAKey::operator==(const RSAKey& key)
{
	if (!isValid() || !key.isValid())
		return false;

	const RSA* tLhs = mKey->key;
	const RSA* tRhs = key.mKey->key;

	if (BN_cmp(tLhs->n, tRhs->n) != 0)
		return false;
	if (BN_cmp(tLhs->e, tRhs->e) != 0)
		return false;
	return BN_cmp(tLhs->d, tRhs->d) == 0;
}

/// Negation of operator==.
bool RSAKey::operator!=(const RSAKey& key)
{
	const bool tEqual = operator==(key);
	return !tEqual;
}

///////////////////////////////////////////////////////////////////////////////

/**
 * Converts @p iNumber to an upper-case hexadecimal string (BN_bn2hex).
 *
 * @param iNumber The BIGNUM to convert; NULL yields an empty string.
 * @param iName   Component name, used only in the error message.
 * @throw XSystemCall if OpenSSL fails to convert the number.
 */
static string rsaKeyBignumToHexString(const BIGNUM* iNumber, const char* iName)
{
	if (iNumber == NULL)
		return("");

	char* tStr = BN_bn2hex(iNumber);
	if (tStr == NULL)
		MIRA_THROW(XSystemCall, "Unable to convert BIGNUM " << iName <<
		           " to string.");

	string tRes(tStr);
	OPENSSL_free(tStr);
	return(tRes);
}

/// Returns the modulus n as a hexadecimal string ("" if not set).
string RSAKey::getNStr() const
{
	return(rsaKeyBignumToHexString(mKey->key->n, "n"));
}

/// Returns the public exponent e as a hexadecimal string ("" if not set).
string RSAKey::getEStr() const
{
	return(rsaKeyBignumToHexString(mKey->key->e, "e"));
}

/// Returns the private exponent d as a hexadecimal string ("" if not set).
string RSAKey::getDStr() const
{
	return(rsaKeyBignumToHexString(mKey->key->d, "d"));
}

///////////////////////////////////////////////////////////////////////////////

void RSAKey::generateKey(unsigned int iKeyBitLength,
                         RSAKey &oPublicKey, RSAKey &oPrivateKey)
{
	// We should use at least 128 bits!
	if (iKeyBitLength < 128)
		MIRA_THROW(XInvalidParameter,
		          "Key bit length should be at least 128 bits.");

	feedRandomNumberGenerator(iKeyBitLength);
	ERR_clear_error();
	RSA* tKey = RSA_generate_key(iKeyBitLength, RSA_F4, NULL, NULL);
	if (tKey == NULL) {
		unsigned long tErrNo = ERR_get_error();
		MIRA_THROW(XSystemCall, "RSA_generate_key failed: " <<
		           OpenSSLErrorString::instance().err2str(tErrNo));
	}

	// duplicate the components and split up into a public and a private key
	OpenSSLRSAWrapper tPublicKey, tPrivateKey;
	tPublicKey.key = RSAPublicKey_dup(tKey);
	tPrivateKey.key = RSAPrivateKey_dup(tKey);

	// copy the key into return values
	oPublicKey = RSAKey(&tPublicKey);
	oPrivateKey = RSAKey(&tPrivateKey);

	// cleanup memory
	RSA_free(tPublicKey.key);
	RSA_free(tPrivateKey.key);
	RSA_free(tKey);
}

void RSAKey::feedRandomNumberGenerator(size_t count)
{
	boost::mt19937 tGenerator(Time::now().toUnixTimestamp());
	boost::uniform_int<> t256(0, 255);
	boost::variate_generator<boost::mt19937&, boost::uniform_int<> >
		tRand(tGenerator, t256);

	// Feed OpenSSL's pseudo random number generator with interesting data :-)
	uint8* tRandData = new uint8[count];
	for(unsigned int i = 0; i < count; i++)
		tRandData[i] = (uint8)tRand();
	RAND_seed(tRandData, count);
	delete [] tRandData;
}

///////////////////////////////////////////////////////////////////////////////

/**
 * Writes @p key to @p stream as "{PUBLIC|PRIVATE}:Len:HexData;".
 *
 * A private key is encoded with i2d_RSAPrivateKey; any other key is encoded
 * with i2d_RSAPublicKey. Len is the number of encoded bytes and HexData is
 * the encoding as two lower-case hex digits per byte.
 *
 * @throw XSystemCall if OpenSSL fails to encode the key. Nothing is written
 *        to the stream in that case.
 */
ostream& operator<<(ostream& stream, const RSAKey& key)
{
	const bool tIsPrivate = key.isPrivateKey();
	const char* tFuncName = tIsPrivate ? "i2d_RSAPrivateKey" : "i2d_RSAPublicKey";
	RSA* tRSA = key.mKey->key;

	// First pass: determine the size of the encoding. OpenSSL signals errors
	// with a return value <= 0, which must not reach the buffer allocation.
	ERR_clear_error();
	int tRequiredLen = tIsPrivate ? i2d_RSAPrivateKey(tRSA, NULL)
	                              : i2d_RSAPublicKey(tRSA, NULL);
	if (tRequiredLen <= 0) {
		unsigned long tErrNo = ERR_get_error();
		MIRA_THROW(XSystemCall, tFuncName << " failed: " <<
		           OpenSSLErrorString::instance().err2str(tErrNo));
	}

	// Second pass: encode into our own buffer. The vector releases itself even
	// if writing to the stream throws. The i2d functions advance the pointer
	// they are given, so a separate cursor is passed.
	vector<unsigned char> tBuffer(static_cast<size_t>(tRequiredLen), 0x00);
	unsigned char* tPtr = &tBuffer[0];

	ERR_clear_error();
	int tWritten = tIsPrivate ? i2d_RSAPrivateKey(tRSA, &tPtr)
	                          : i2d_RSAPublicKey(tRSA, &tPtr);
	if (tWritten != tRequiredLen) {
		unsigned long tErrNo = ERR_get_error();
		MIRA_THROW(XSystemCall, tFuncName << " failed: " <<
		           OpenSSLErrorString::instance().err2str(tErrNo));
	}

	stream << (tIsPrivate ? "PRIVATE:" : "PUBLIC:");

	// Store key in the stream: "Len:Data;"
	stream << tRequiredLen << ":";
	for(int i = 0; i < tRequiredLen; i++)
		stream << boost::format("%02x") % static_cast<int>(tBuffer[i]);
	stream << ";";

	return(stream);
}

/// Returns the value of the hexadecimal digit @p c, or -1 if @p c is not one.
static int rsaKeyHexDigitValue(char c)
{
	if ((c >= '0') && (c <= '9'))
		return c - '0';
	if ((c >= 'a') && (c <= 'f'))
		return c - 'a' + 10;
	if ((c >= 'A') && (c <= 'F'))
		return c - 'A' + 10;
	return -1;
}

/**
 * Reads a key written by operator<< ("{PUBLIC|PRIVATE}:Len:HexData;") from
 * @p stream into @p key.
 *
 * The data is decoded into a fresh OpenSSL object, which replaces the old key
 * only after decoding succeeded. Decoding into the existing object could keep
 * stale fields the new encoding does not contain (e.g. the private exponent
 * when a PUBLIC key is read into a former private key). A failed decode could
 * also leave the existing object in an unusable state.
 *
 * @throw XInvalidParameter on malformed input. @p key is unchanged.
 * @throw XSystemCall if OpenSSL fails to decode the key. @p key is unchanged.
 */
istream& operator>>(istream& stream, RSAKey& key)
{
	string tIn;
	stream >> tIn;

	// Expected format: {PUBLIC|PRIVATE}:Len:HexData;
	vector<string> tParts;
	boost::algorithm::split(tParts, tIn, boost::is_from_range(':',':'));

	///////////////////////////////////////////////////////////////////////////
	// check content of the stream

	if (tParts.size() != 3) {
		MIRA_THROW(XInvalidParameter,
		           "Unexpected stream data. "
		           "Format should be: {PUBLIC|PRIVATE}:Len:HexData;");
	}
	if ((tParts[0] != "PUBLIC") && (tParts[0] != "PRIVATE")) {
		MIRA_THROW(XInvalidParameter,
		           "Unexpected stream data. "
		           "Key type must be PUBLIC or PRIVATE.");
	}
	const bool tIsPrivate = (tParts[0] == "PRIVATE");

	int tLen = boost::lexical_cast<int>(tParts[1]);
	// Reject non-positive lengths, and compute the expected size in size_t so
	// that a huge length cannot overflow (2*tLen+1 would be signed overflow).
	if ((tLen <= 0) ||
	    (tParts[2].size() != 2 * static_cast<size_t>(tLen) + 1)) {
		MIRA_THROW(XInvalidParameter,
		           "Unexpected stream data. Invalid number of data bytes.");
	}

	///////////////////////////////////////////////////////////////////////////
	// convert string data into binary buffer

	const string& tHex = tParts[2];
	vector<unsigned char> tBuffer(static_cast<size_t>(tLen));
	for(int i = 0; i < tLen; i++) {
		int tHigh = rsaKeyHexDigitValue(tHex[2*i]);
		int tLow  = rsaKeyHexDigitValue(tHex[2*i+1]);
		if ((tHigh < 0) || (tLow < 0)) {
			MIRA_THROW(XInvalidParameter,
			           "Unexpected stream data. Invalid hex digit.");
		}
		tBuffer[i] = static_cast<unsigned char>((tHigh << 4) | tLow);
	}

	///////////////////////////////////////////////////////////////////////////
	// convert binary data in RSA key structure

	const char* tFuncName = tIsPrivate ? "d2i_RSAPrivateKey" : "d2i_RSAPublicKey";

	ERR_clear_error();
	const unsigned char* tPtr = &tBuffer[0];
	RSA* tNewKey = NULL;
	if (tIsPrivate)
		tNewKey = d2i_RSAPrivateKey(NULL, &tPtr, tLen);
	else
		tNewKey = d2i_RSAPublicKey(NULL, &tPtr, tLen);

	if (tNewKey == NULL) {
		unsigned long tErrNo = ERR_get_error();
		MIRA_THROW(XSystemCall, tFuncName << " failed: " <<
		           OpenSSLErrorString::instance().err2str(tErrNo));
	}

	// Decoding succeeded: replace the old key.
	RSA_free(key.mKey->key);
	key.mKey->key = tNewKey;

	return(stream);
}

///////////////////////////////////////////////////////////////////////////////

} // namespace
