/*
 * Copyright (C) by
 *   MetraLabs GmbH (MLAB), GERMANY
 * and
 *   Neuroinformatics and Cognitive Robotics Labs (NICR) at TU Ilmenau, GERMANY
 * All rights reserved.
 *
 * Contact: info@mira-project.org
 *
 * Commercial Usage:
 *   Licensees holding valid commercial licenses may use this file in
 *   accordance with the commercial license agreement provided with the
 *   software or, alternatively, in accordance with the terms contained in
 *   a written agreement between you and MLAB or NICR.
 *
 * GNU General Public License Usage:
 *   Alternatively, this file may be used under the terms of the GNU
 *   General Public License version 3.0 as published by the Free Software
 *   Foundation and appearing in the file LICENSE.GPL3 included in the
 *   packaging of this file. Please review the following information to
 *   ensure the GNU General Public License version 3.0 requirements will be
 *   met: http://www.gnu.org/copyleft/gpl.html.
 *   Alternatively you may (at your option) use any later version of the GNU
 *   General Public License if such license has been publicly approved by
 *   MLAB and NICR (or its successors, if any).
 *
 * IN NO EVENT SHALL "MLAB" OR "NICR" BE LIABLE TO ANY PARTY FOR DIRECT,
 * INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
 * THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF "MLAB" OR
 * "NICR" HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * "MLAB" AND "NICR" SPECIFICALLY DISCLAIM ANY WARRANTIES, INCLUDING,
 * BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS
 * ON AN "AS IS" BASIS, AND "MLAB" AND "NICR" HAVE NO OBLIGATION TO
 * PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS OR MODIFICATIONS.
 */

/**
 * @file RSAFilter.C
 *    Implementation of RSAFilter.h
 *
 * @author Christian Martin
 * @date   2023/12/xx
 */

#include <security/RSAFilter.h>

#include <boost/asio/buffers_iterator.hpp>

#include <error/Exceptions.h>

#include "../OpenSSLHelper.h"

using namespace std;

namespace mira {

///////////////////////////////////////////////////////////////////////////////

namespace Private {

///////////////////////////////////////////////////////////////////////////////

struct RSAFilterBase::KeyCtx
{
	// OpenSSL public key contexts for the two directions. They are created
	// on first use by encryptPublic()/decryptPrivate() and released by
	// ~RSAFilterBase(). A NULL pointer means "not created yet".
	EVP_PKEY_CTX* encrypt = NULL;
	EVP_PKEY_CTX* decrypt = NULL;

	KeyCtx() {}
};

///////////////////////////////////////////////////////////////////////////////

RSAFilterBase::RSAFilterBase()
	: mWrkBuffer(nullptr)
{
	// The OpenSSL contexts inside KeyCtx start out empty and are created on
	// demand. The work buffer is allocated later by initFilter(), once the
	// key size is known.
	mCtx = new KeyCtx();
}

RSAFilterBase::~RSAFilterBase()
{
	// The work buffer does not depend on the OpenSSL state; release it first.
	delete [] mWrkBuffer;
	mWrkBuffer = NULL;

	// Free whichever OpenSSL contexts were actually created.
	EVP_PKEY_CTX** ctxSlots[] = { &mCtx->encrypt, &mCtx->decrypt };
	for (EVP_PKEY_CTX** slot : ctxSlots) {
		if (*slot != NULL) {
			EVP_PKEY_CTX_free(*slot);
			*slot = NULL;
		}
	}

	delete mCtx;
	mCtx = NULL;
}

/*
 * Encrypts the plaintext in [ioSrcBegin, iSrcEnd) with the RSA key.
 *
 * All incoming bytes are appended to mInputBuffer and ioSrcBegin is advanced
 * to iSrcEnd. At most one RSA block is encrypted per call:
 *  - a full block of mBlockSize bytes, or
 *  - when iFlush is set, the remaining (possibly shorter) tail.
 * An empty input buffer is never encrypted. The resulting ciphertext is
 * collected in mOutputBuffer, and as much of it as fits is copied to
 * [ioDestBegin, iDestEnd); ioDestBegin is advanced accordingly.
 *
 * @return true if buffered input or output remains (call again).
 * @throw XSystemCall if an OpenSSL call fails.
 */
bool RSAFilterBase::encryptPublic(const char*& ioSrcBegin,
                                  const char*  iSrcEnd,
                                  char*&       ioDestBegin,
                                  char*        iDestEnd,
                                  bool         iFlush)
{
	////////////////////////////////////////////////////////////////
	// initialize encrypt context if necessary

	if (!mCtx->encrypt) {
		mCtx->encrypt = EVP_PKEY_CTX_new(mKey.getOpenSSLKey()->key, NULL);
		if (!mCtx->encrypt) {
			MIRA_THROW(XSystemCall, "EVP_PKEY_CTX_new failed: " <<
			           OpenSSLErrorString::instance().err2str(ERR_get_error()));
		}

		if (EVP_PKEY_encrypt_init(mCtx->encrypt) <= 0) {
			auto errNo = ERR_get_error();
			EVP_PKEY_CTX_free(mCtx->encrypt);
			mCtx->encrypt = NULL;
			MIRA_THROW(XSystemCall, "EVP_PKEY_encrypt_init failed: " <<
			           OpenSSLErrorString::instance().err2str(errNo));
		}

		if (EVP_PKEY_CTX_set_rsa_padding(mCtx->encrypt, RSA_PKCS1_PADDING) <= 0) {
			auto errNo = ERR_get_error();
			EVP_PKEY_CTX_free(mCtx->encrypt);
			mCtx->encrypt = NULL;
			MIRA_THROW(XSystemCall, "EVP_PKEY_CTX_set_rsa_padding failed: " <<
			           OpenSSLErrorString::instance().err2str(errNo));
		}
	}

	////////////////////////////////////////////////////////////////
	// store new incoming unencrypted data into mInputBuffer

	if (iSrcEnd != ioSrcBegin) {
		size_t inputBlockSize = iSrcEnd - ioSrcBegin;
		mInputBuffer.sputn(ioSrcBegin, inputBlockSize);
		ioSrcBegin = iSrcEnd;
	}

	////////////////////////////////////////////////////////////////
	// If the input buffer holds a full block, or we are flushing, we can
	// encrypt. An empty buffer is skipped explicitly for two reasons:
	// encrypting it would append a spurious ciphertext block for an empty
	// plaintext on every repeated flush call, and it would dereference the
	// begin iterator of an empty buffer.

	const size_t pendingInput = mInputBuffer.size();
	if ((pendingInput > 0) && (iFlush || (pendingInput >= mBlockSize))) {
		// Feed the random number generator
		RSAKey::feedRandomNumberGenerator(8*mRSASize);

		size_t inBlockLen = pendingInput;
		if (inBlockLen > mBlockSize)
			inBlockLen = mBlockSize;

		const unsigned char* inDataPtr = (const unsigned char*)
			&(*(boost::asio::buffers_begin(mInputBuffer.data())));

		size_t outSize = mRSASize;
		if (EVP_PKEY_encrypt(mCtx->encrypt,
		                     mWrkBuffer, &outSize,
		                     inDataPtr, inBlockLen) <= 0)
		{
			auto errNo = ERR_get_error();
			MIRA_THROW(XSystemCall, "EVP_PKEY_encrypt failed: " <<
			           OpenSSLErrorString::instance().err2str(errNo));
		}

		// remove unencrypted data from input buffer
		mInputBuffer.consume(inBlockLen);

		// append encrypted data to output buffer
		mOutputBuffer.sputn((const char*)mWrkBuffer, outSize);
	}

	////////////////////////////////////////////////////////////////
	// output as much data as possible

	size_t spaceInDestBuffer = iDestEnd - ioDestBegin;
	size_t copySize = std::min(mOutputBuffer.size(), spaceInDestBuffer);

	// Only touch the output buffer's data when there is something to copy:
	// dereferencing the begin iterator of an empty buffer is undefined.
	if (copySize > 0) {
		const char *outBufferPtr =
			&(*(boost::asio::buffers_begin(mOutputBuffer.data())));

		memcpy(ioDestBegin, outBufferPtr, copySize);

		mOutputBuffer.consume(copySize);
		ioDestBegin += copySize;
	}

	bool callAgain = (mOutputBuffer.size() > 0) || (mInputBuffer.size() > 0);
	return callAgain;
}

bool RSAFilterBase::decryptPrivate(const char*& ioSrcBegin,
                                   const char*  iSrcEnd,
                                   char*&       ioDestBegin,
                                   char*        iDestEnd,
                                   bool         iFlush)
{
	////////////////////////////////////////////////////////////////
	// initialize decrypt context if necessary

	if (!mCtx->decrypt) {
		mCtx->decrypt = EVP_PKEY_CTX_new(mKey.getOpenSSLKey()->key, NULL);
		if (!mCtx->decrypt) {
			MIRA_THROW(XSystemCall, "EVP_PKEY_CTX_new failed:" <<
			           OpenSSLErrorString::instance().err2str(ERR_get_error()));
		}

		if (EVP_PKEY_decrypt_init(mCtx->decrypt) <= 0) {
			auto errNo = ERR_get_error();
			EVP_PKEY_CTX_free(mCtx->decrypt);
			mCtx->decrypt = NULL;
			MIRA_THROW(XSystemCall, "EVP_PKEY_decrypt_init failed: " <<
			           OpenSSLErrorString::instance().err2str(errNo));
		}

		if (EVP_PKEY_CTX_set_rsa_padding(mCtx->decrypt, RSA_PKCS1_PADDING) <= 0) {
			auto errNo = ERR_get_error();
			EVP_PKEY_CTX_free(mCtx->decrypt);
			mCtx->decrypt = NULL;
			MIRA_THROW(XSystemCall, "EVP_PKEY_CTX_set_rsa_padding failed: " <<
			           OpenSSLErrorString::instance().err2str(errNo));
		}
	}

	////////////////////////////////////////////////////////////////
	// store new incoming encrypted data into mInputBuffer

	if (iSrcEnd != ioSrcBegin) {
		size_t inputBlockSize = iSrcEnd - ioSrcBegin;
		mInputBuffer.sputn(ioSrcBegin, inputBlockSize);
		ioSrcBegin = iSrcEnd;
	}

	////////////////////////////////////////////////////////////////
	// if the input buffer has enough data we can decrypt the data

	if (mInputBuffer.size() >= mRSASize) {

		const unsigned char* inDataPtr = (const unsigned char*)
			&(*(boost::asio::buffers_begin(mInputBuffer.data())));

		size_t outSize = mRSASize;
		if (EVP_PKEY_decrypt(mCtx->decrypt,
		                     mWrkBuffer, &outSize,
		                     inDataPtr, mRSASize) <= 0)
		{
			auto errNo = ERR_get_error();
			MIRA_THROW(XSystemCall, "EVP_PKEY_decrypt failed: " <<
			           OpenSSLErrorString::instance().err2str(errNo));
		}

		// remove encrypted data from input buffer
		mInputBuffer.consume(mRSASize);

		// append decrypted data to output buffer
		mOutputBuffer.sputn((const char*)mWrkBuffer, outSize);
	}

	///////////////////////////////////////////////////////////////////////////
	// output as much data as possible

	size_t spaceInDestBuffer = iDestEnd - ioDestBegin;
	size_t copySize = std::min(mOutputBuffer.size(), spaceInDestBuffer);

	const char *outBufferPtr =
		&(*(boost::asio::buffers_begin(mOutputBuffer.data())));

	memcpy(ioDestBegin, outBufferPtr, copySize);

	mOutputBuffer.consume(copySize);
	ioDestBegin += copySize;

	bool callAgain = (mOutputBuffer.size() > 0) || (mInputBuffer.size() > 0);
	return callAgain;
}

void RSAFilterBase::reset()
{
	// Discard everything that is still buffered in either direction.
	// The OpenSSL contexts and the work buffer are kept.
	const size_t pendingIn  = mInputBuffer.size();
	const size_t pendingOut = mOutputBuffer.size();

	mInputBuffer.consume(pendingIn);
	mOutputBuffer.consume(pendingOut);
}

void RSAFilterBase::initFilter(const RSAKey& key, bool encrypt, void* alloc)
{
	mKey = key;

	// We use PKCS#1 v1.5 padding (RSA_PKCS1_PADDING). This means that values
	// being encrypted must be less than the size of the modulus (=mRSASize)
	// in bytes minus 11 bytes long.
	mRSASize = EVP_PKEY_get_size(mKey.getOpenSSLKey()->key);
	mBlockSize = mRSASize - 11; // for RSA_PKCS1_PADDING

	if (mWrkBuffer != NULL)
		delete [] mWrkBuffer;
	mWrkBuffer = new uint8[mRSASize];
}

///////////////////////////////////////////////////////////////////////////////

} // end of namespace Private

///////////////////////////////////////////////////////////////////////////////

} // namespace
