/*
 * Copyright (C) 2012 by
 *   MetraLabs GmbH (MLAB), GERMANY
 * and
 *   Neuroinformatics and Cognitive Robotics Labs (NICR) at TU Ilmenau, GERMANY
 * All rights reserved.
 *
 * Contact: info@mira-project.org
 *
 * Commercial Usage:
 *   Licensees holding valid commercial licenses may use this file in
 *   accordance with the commercial license agreement provided with the
 *   software or, alternatively, in accordance with the terms contained in
 *   a written agreement between you and MLAB or NICR.
 *
 * GNU General Public License Usage:
 *   Alternatively, this file may be used under the terms of the GNU
 *   General Public License version 3.0 as published by the Free Software
 *   Foundation and appearing in the file LICENSE.GPL3 included in the
 *   packaging of this file. Please review the following information to
 *   ensure the GNU General Public License version 3.0 requirements will be
 *   met: http://www.gnu.org/copyleft/gpl.html.
 *   Alternatively you may (at your option) use any later version of the GNU
 *   General Public License if such license has been publicly approved by
 *   MLAB and NICR (or its successors, if any).
 *
 * IN NO EVENT SHALL "MLAB" OR "NICR" BE LIABLE TO ANY PARTY FOR DIRECT,
 * INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
 * THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF "MLAB" OR
 * "NICR" HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * "MLAB" AND "NICR" SPECIFICALLY DISCLAIM ANY WARRANTIES, INCLUDING,
 * BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS
 * ON AN "AS IS" BASIS, AND "MLAB" AND "NICR" HAVE NO OBLIGATION TO
 * PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS OR MODIFICATIONS.
 */

/**
 * @file AESFilter.C
 *    Implementation of AESFilter.h
 *
 * @author Christian Martin
 * @date   2012/06/17
 */

#include <security/AESFilter.h>

#include <boost/asio/buffers_iterator.hpp>

#include <error/Exceptions.h>

#include "OpenSSLHelper.h"

using namespace std;

namespace mira {

///////////////////////////////////////////////////////////////////////////////

namespace Private {

///////////////////////////////////////////////////////////////////////////////

/// Private implementation data of AESFilterBase (OpenSSL state and key).
struct AESFilterBase::Context {
	// Ensure the cipher context pointer starts out NULL no matter how the
	// Context object is allocated (e.g. `new Context` without parentheses),
	// so the cleanup code can reliably test it.
	Context() : ctx(NULL) {}

	EVP_CIPHER_CTX* ctx; ///< cipher context, created in initFilter()
	std::string     key; ///< password/key material passed to initFilter()
};

///////////////////////////////////////////////////////////////////////////////

/**
 * Constructs an uninitialized filter in encryption mode. The work buffer
 * is empty until reset() is called (via initFilter()).
 */
AESFilterBase::AESFilterBase() :
	mEncryptMode(true),
	mFinalCalled(false),
	mWrkBuffer(NULL)
{
	// Keep the size in sync with the (still empty) work buffer instead of
	// leaving it indeterminate until reset() is called.
	mWrkBufferSize = 0;

	mCtx = new Context();
	mCtx->ctx = NULL; // created later by initFilter()
}

/**
 * Releases the OpenSSL cipher context, the private context object and the
 * work buffer.
 */
AESFilterBase::~AESFilterBase()
{
	if (mCtx != NULL) {
		// The cipher context is allocated with EVP_CIPHER_CTX_new() in
		// initFilter(); it must be released explicitly or it leaks.
		if (mCtx->ctx != NULL) {
			EVP_CIPHER_CTX_free(mCtx->ctx);
			mCtx->ctx = NULL;
		}
		delete mCtx;
		mCtx = NULL;
	}

	delete [] mWrkBuffer;
	mWrkBuffer = NULL;
}

bool AESFilterBase::encrypt(const char*& ioSrcBegin,
                            const char*  iSrcEnd,
                            char*&       ioDestBegin,
                            char*        iDestEnd,
                            bool         iFlush)
{
	///////////////////////////////////////////////////////////////////////////////
	// store new incoming encrypted data into mInputBuffer

	if (iSrcEnd != ioSrcBegin) {
		size_t tInputBlockSize = iSrcEnd - ioSrcBegin;
		mInputBuffer.sputn(ioSrcBegin, tInputBlockSize);
		ioSrcBegin = iSrcEnd;
	}

	///////////////////////////////////////////////////////////////////////////
	// as long as we have data in the input buffer, we encrypt the data

	if (mInputBuffer.size() > 0) {

		if (mFinalCalled)
			MIRA_THROW(XRuntime, "EVP_EncryptFinal_ex already called.");

		// encrypt not more than 64kByte at once
		size_t tInputSize = mInputBuffer.size();
		if (tInputSize > 65536)
			tInputSize = 65536;

		// ensure, that we have enough space in the working output buffer
		size_t tMaxCipherLen = tInputSize + AES_BLOCK_SIZE - 1;
		if (tMaxCipherLen > mWrkBufferSize) {
			mWrkBufferSize = tMaxCipherLen;
			if (mWrkBuffer)
				delete [] mWrkBuffer;
			mWrkBuffer = new uint8[mWrkBufferSize];
		}

		// get the plain data from the input buffer
		const char *tDataPlain =
			&(*(boost::asio::buffers_begin(mInputBuffer.data())));

		// encrypt the data
		int tOutputLen = 0;
		if (!EVP_EncryptUpdate(mCtx->ctx, mWrkBuffer, &tOutputLen,
		                       (const unsigned char*)tDataPlain, tInputSize))
		{
			unsigned long tErrNo = ERR_get_error();
			MIRA_THROW(XSystemCall, "EVP_EncryptUpdate failed: " <<
			           OpenSSLErrorString::instance().err2str(tErrNo));
		}

		// remove data from input buffer
		mInputBuffer.consume(tInputSize);

		// append data to output buffer
		mOutputBuffer.sputn((const char*)mWrkBuffer, tOutputLen);
	}

	///////////////////////////////////////////////////////////////////////////
	// handle the remaining bytes

	if (iFlush && (mInputBuffer.size() == 0) && !mFinalCalled) {
		// encrypt the remaining block
		int tOutputLen = 0;
		if (!EVP_EncryptFinal_ex(mCtx->ctx, mWrkBuffer, &tOutputLen))
		{
			unsigned long tErrNo = ERR_get_error();
			MIRA_THROW(XSystemCall, "EVP_EncryptFinal_ex failed: " <<
			           OpenSSLErrorString::instance().err2str(tErrNo));
		}

		// append data to output buffer
		mOutputBuffer.sputn((const char*)mWrkBuffer, tOutputLen);

		mFinalCalled = true;
	}

	///////////////////////////////////////////////////////////////////////////
	// output as much data as possible

	size_t tSpaceInBuffer = iDestEnd - ioDestBegin;
	const char *tDataPlain =
		&(*(boost::asio::buffers_begin(mOutputBuffer.data())));

	size_t tCopySize = mOutputBuffer.size();
	if (tCopySize > tSpaceInBuffer)
		tCopySize = tSpaceInBuffer;

	memcpy(ioDestBegin, tDataPlain, tCopySize);

	mOutputBuffer.consume(tCopySize);
	ioDestBegin += tCopySize;

	bool tCallAgain =
			(mOutputBuffer.size() > 0) ||
			(mInputBuffer.size() > 0) ||
			!mFinalCalled;
	return(tCallAgain);
}

bool AESFilterBase::decrypt(const char*& ioSrcBegin,
                            const char*  iSrcEnd,
                            char*&       ioDestBegin,
                            char*        iDestEnd,
                            bool         iFlush)
{
	///////////////////////////////////////////////////////////////////////////
	// store new incoming encrypted data into mInputBuffer

	if (iSrcEnd != ioSrcBegin) {
		size_t tInputBlockSize = iSrcEnd - ioSrcBegin;
		mInputBuffer.sputn(ioSrcBegin, tInputBlockSize);
		ioSrcBegin = iSrcEnd;
	}

	///////////////////////////////////////////////////////////////////////////
	// if the input buffer has enough data or flush is true encrypt the data

	if (mInputBuffer.size() > 0) {
		if (mFinalCalled)
			MIRA_THROW(XRuntime, "EVP_DecryptFinal_ex already called.");

		// decrypt not more than 64kByte at once
		size_t tInputSize = mInputBuffer.size();
		if (tInputSize > 65536)
			tInputSize = 65536;

		// ensure, that we have enough space in the buffer
		if (tInputSize > mWrkBufferSize) {
			mWrkBufferSize = tInputSize;
			if (mWrkBuffer)
				delete [] mWrkBuffer;
			mWrkBuffer = new uint8[mWrkBufferSize + AES_BLOCK_SIZE];
		}

		// get the encrypted data
		const char *tDataPlain =
			&(*(boost::asio::buffers_begin(mInputBuffer.data())));

		// decrypt the data
		int tOutputLen = 0;
		if (!EVP_DecryptUpdate(mCtx->ctx, mWrkBuffer, &tOutputLen,
		                       (const unsigned char*)tDataPlain, tInputSize))
		{
			unsigned long tErrNo = ERR_get_error();
			MIRA_THROW(XSystemCall, "EVP_DecryptUpdate failed: " <<
			           OpenSSLErrorString::instance().err2str(tErrNo));
		}

		// remove data from input buffer
		mInputBuffer.consume(tInputSize);

		// append data to output buffer
		mOutputBuffer.sputn((const char*)mWrkBuffer, tOutputLen);
	}

	if (iFlush && (mInputBuffer.size() == 0) && !mFinalCalled) {

		// decrypt the remaining block
		int tOutputLen = 0;
		if (!EVP_DecryptFinal_ex(mCtx->ctx, mWrkBuffer, &tOutputLen))
		{
			unsigned long tErrNo = ERR_get_error();
			MIRA_THROW(XSystemCall, "EVP_DecryptFinal_ex failed: " <<
			           OpenSSLErrorString::instance().err2str(tErrNo));
		}

		// append data to output buffer
		mOutputBuffer.sputn((const char*)mWrkBuffer, tOutputLen);

		mFinalCalled = true;
	}

	///////////////////////////////////////////////////////////////////////////
	// output as much data as possible

	size_t tSpaceInBuffer = iDestEnd - ioDestBegin;
	const char *tDataPlain =
		&(*(boost::asio::buffers_begin(mOutputBuffer.data())));

	size_t tCopySize = mOutputBuffer.size();
	if (tCopySize > tSpaceInBuffer)
		tCopySize = tSpaceInBuffer;

	memcpy(ioDestBegin, tDataPlain, tCopySize);

	mOutputBuffer.consume(tCopySize);
	ioDestBegin += tCopySize;

	bool tCallAgain =
			(mOutputBuffer.size() > 0) ||
			(mInputBuffer.size() > 0) ||
			!mFinalCalled;
	return(tCallAgain);
}

void AESFilterBase::reset()
{
	// Create an initial work buffer of 128bit (= AES_BLOCK_SIZE)
	mWrkBufferSize = AES_BLOCK_SIZE;
	if (mWrkBuffer != NULL)
		delete [] mWrkBuffer;
	mWrkBuffer = new uint8[mWrkBufferSize];

	// Cleanup all working buffers
	mInputBuffer.consume(mInputBuffer.size());
	mOutputBuffer.consume(mOutputBuffer.size());

	///////////////////////////////////////////////////////////////////////////
	// Find the matching encryption/decryption function

	static const struct {
		AESBitLength       mode;
		AESBlockCipherMode block_cipher;
		const EVP_CIPHER*  func;
		int                key_length;
	} sAESFuncList[] = {
		{ AES_128, AES_ECB,    EVP_aes_128_ecb(),    128 },
		{ AES_128, AES_CBC,    EVP_aes_128_cbc(),    128 },
		{ AES_128, AES_CFB1,   EVP_aes_128_cfb1(),   128 },
		{ AES_128, AES_CFB8,   EVP_aes_128_cfb8(),   128 },
		{ AES_128, AES_CFB128, EVP_aes_128_cfb128(), 128 },
		{ AES_128, AES_OFB,    EVP_aes_128_ofb(),    128 },

		{ AES_192, AES_ECB,    EVP_aes_192_ecb(),    192 },
		{ AES_192, AES_CBC,    EVP_aes_192_cbc(),    192 },
		{ AES_192, AES_CFB1,   EVP_aes_192_cfb1(),   192 },
		{ AES_192, AES_CFB8,   EVP_aes_192_cfb8(),   192 },
		{ AES_192, AES_CFB128, EVP_aes_192_cfb128(), 192 },
		{ AES_192, AES_OFB,    EVP_aes_192_ofb(),    192 },

		{ AES_256, AES_ECB,    EVP_aes_256_ecb(),    256 },
		{ AES_256, AES_CBC,    EVP_aes_256_cbc(),    256 },
		{ AES_256, AES_CFB1,   EVP_aes_256_cfb1(),   256 },
		{ AES_256, AES_CFB8,   EVP_aes_256_cfb8(),   256 },
		{ AES_256, AES_CFB128, EVP_aes_256_cfb128(), 256 },
		{ AES_256, AES_OFB,    EVP_aes_256_ofb(),    256 },

		{ (AESBitLength)0, (AESBlockCipherMode)0, NULL,0 }
	};

	const EVP_CIPHER* aesFunc = NULL;
	int keyBitLength = 0;
	int i = 0;
	while(sAESFuncList[i].func != NULL) {
		if ((sAESFuncList[i].mode == mCfg.bitLength) &&
			(sAESFuncList[i].block_cipher == mCfg.blockCipherMode))
		{
			aesFunc = sAESFuncList[i].func;
			keyBitLength = sAESFuncList[i].key_length;
			break;
		}
		i++;
	}
	if (aesFunc == NULL)
		MIRA_THROW(XInvalidParameter,
		           "Can't find AES function for given configuration.");

	///////////////////////////////////////////////////////////////////////////
	// Message digest algorithm

	const EVP_MD* mdFunc = NULL;
	if (mCfg.mdAlgo == AES_MD_MD5)
		mdFunc = EVP_md5();
	else if (mCfg.mdAlgo == AES_MD_SHA1)
		mdFunc = EVP_sha1();
	else if (mCfg.mdAlgo == AES_MD_SHA256)
		mdFunc = EVP_sha256();
	else if (mCfg.mdAlgo == AES_MD_SHA512)
		mdFunc = EVP_sha512();

	if (mdFunc == NULL)
		MIRA_THROW(XInvalidParameter,
		           "Can't find message digest algorithm for given configuration.");

	///////////////////////////////////////////////////////////////////////////
	// Salt configuration

	const unsigned char* saltPtr = NULL;

	if (mCfg.salt.size() > 0) {
		if (mCfg.salt.size() != 8)
			MIRA_THROW(XInvalidParameter, "Salt must have length 8 or zero.");
		saltPtr = (const unsigned char*)mCfg.salt.data();
	}

	///////////////////////////////////////////////////////////////////////////
	// Generate key and IV for AES

	unsigned char key[32], iv[32];
	int tKeyLen = EVP_BytesToKey(aesFunc, mdFunc, saltPtr,
	                             (const unsigned char*)mCtx->key.data(),
	                             mCtx->key.size(), mCfg.nrRounds, key, iv);
	if (8*tKeyLen != keyBitLength)  {
		MIRA_THROW(XSystemCall, "AES_init failed: Key size is " << 8*tKeyLen <<
		           " bits - should be " << keyBitLength << " bits");
	}

	///////////////////////////////////////////////////////////////////////////
	// Initialize encryption/decryption

	if (mEncryptMode) {
		if (EVP_EncryptInit_ex(mCtx->ctx, aesFunc, NULL, key, iv) != 1) {
			unsigned long tErrNo = ERR_get_error();
			MIRA_THROW(XSystemCall, "EVP_EncryptInit_ex failed: " <<
			           OpenSSLErrorString::instance().err2str(tErrNo));
		}
	} else {
		if (EVP_DecryptInit_ex(mCtx->ctx, aesFunc, NULL, key, iv) != 1) {
			unsigned long tErrNo = ERR_get_error();
			MIRA_THROW(XSystemCall, "EVP_DecryptInit_ex failed: " <<
			           OpenSSLErrorString::instance().err2str(tErrNo));
		}
	}

	mFinalCalled = false;
}

void AESFilterBase::initFilter(const AESConfiguration& cfg,
                               const std::string& key,
                               bool encrypt, void* alloc)
{
	mCfg = cfg;
	mEncryptMode = encrypt;

	mCtx->ctx = EVP_CIPHER_CTX_new();
	mCtx->key = key;

	reset();
}

///////////////////////////////////////////////////////////////////////////////

} // end of namespace Private

///////////////////////////////////////////////////////////////////////////////

} // namespace
