/*
 * Copyright (C) 2012 by
 *   MetraLabs GmbH (MLAB), GERMANY
 * and
 *   Neuroinformatics and Cognitive Robotics Labs (NICR) at TU Ilmenau, GERMANY
 * All rights reserved.
 *
 * Contact: info@mira-project.org
 *
 * Commercial Usage:
 *   Licensees holding valid commercial licenses may use this file in
 *   accordance with the commercial license agreement provided with the
 *   software or, alternatively, in accordance with the terms contained in
 *   a written agreement between you and MLAB or NICR.
 *
 * GNU General Public License Usage:
 *   Alternatively, this file may be used under the terms of the GNU
 *   General Public License version 3.0 as published by the Free Software
 *   Foundation and appearing in the file LICENSE.GPL3 included in the
 *   packaging of this file. Please review the following information to
 *   ensure the GNU General Public License version 3.0 requirements will be
 *   met: http://www.gnu.org/copyleft/gpl.html.
 *   Alternatively you may (at your option) use any later version of the GNU
 *   General Public License if such license has been publicly approved by
 *   MLAB and NICR (or its successors, if any).
 *
 * IN NO EVENT SHALL "MLAB" OR "NICR" BE LIABLE TO ANY PARTY FOR DIRECT,
 * INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
 * THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF "MLAB" OR
 * "NICR" HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * "MLAB" AND "NICR" SPECIFICALLY DISCLAIM ANY WARRANTIES, INCLUDING,
 * BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS
 * ON AN "AS IS" BASIS, AND "MLAB" AND "NICR" HAVE NO OBLIGATION TO
 * PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS OR MODIFICATIONS.
 */

/**
 * @file RSAFilter.C
 *    Implementation of RSAFilter.h
 *
 * @author Christian Martin
 * @date   2010/09/20
 */

#include <security/RSAFilter.h>

#include <boost/asio/buffers_iterator.hpp>

#include <error/Exceptions.h>

#include "OpenSSLHelper.h"

using namespace std;

namespace mira {

///////////////////////////////////////////////////////////////////////////////

namespace Private {

///////////////////////////////////////////////////////////////////////////////

RSAFilterBase::RSAFilterBase()
{
	// No key is known yet, so there is no work buffer either; it is
	// allocated (with the key's modulus size) by initFilter().
	mWrkBuffer = NULL;
}

RSAFilterBase::~RSAFilterBase()
{
	// Release the work buffer allocated by initFilter(), if there is one.
	// NOTE(review): mWrkBuffer is an owning raw pointer; unless the header
	// forbids copying RSAFilterBase, a copied instance would delete the same
	// buffer twice - confirm.
	if (mWrkBuffer != NULL) {
		delete [] mWrkBuffer;
		mWrkBuffer = NULL;
	}
}

bool RSAFilterBase::encryptPublic(const char*& ioSrcBegin,
                                  const char*  iSrcEnd,
                                  char*&       ioDestBegin,
                                  char*        iDestEnd,
                                  bool         iFlush)
{
	// Encrypts the plain text [ioSrcBegin, iSrcEnd) block-wise with the
	// public key (PKCS#1 v1.5 padding) and writes as much cipher text as fits
	// into [ioDestBegin, iDestEnd). ioSrcBegin/ioDestBegin are advanced past
	// the consumed/produced bytes. Returns true while plain or encrypted data
	// is still buffered, i.e. the caller has to call again.
	// Throws XSystemCall if OpenSSL fails to encrypt a block.

	///////////////////////////////////////////////////////////////////////////
	// store new incoming plain text data into mInputBuffer

	if (iSrcEnd != ioSrcBegin) {
		size_t tInputBlockSize = iSrcEnd - ioSrcBegin;
		mInputBuffer.sputn(ioSrcBegin, tInputBlockSize);
		ioSrcBegin = iSrcEnd;
	}

	///////////////////////////////////////////////////////////////////////////
	// Encrypt every complete block; on flush, also encrypt the trailing
	// partial block. An empty input buffer is never encrypted: that would
	// emit a spurious cipher block on every flush call and, whenever the
	// destination cannot take the whole output at once, keep the output
	// buffer growing so that the flush loop never terminates.

	while ((mInputBuffer.size() > 0) &&
	       (iFlush || (mInputBuffer.size() >= mBlockSize))) {
		// Feed the random number generator
		RSAKey::feedRandomNumberGenerator(8*mRSASize);

		size_t tBlockLen = mInputBuffer.size();
		if (tBlockLen > mBlockSize)
			tBlockLen = mBlockSize;

		// the streambuf data is contiguous, so the block can be passed as-is
		const char *tDataPlain =
			&(*(boost::asio::buffers_begin(mInputBuffer.data())));
		int tSize = RSA_public_encrypt(tBlockLen,
		                               (unsigned char*)tDataPlain,
		                               mWrkBuffer,
		                               mKey.getOpenSSLKey()->key, // public key
		                               RSA_PKCS1_PADDING);
		// every encrypted block must have exactly the size of the modulus
		if (tSize != (int)mRSASize) {
			unsigned long tErrNo = ERR_get_error();
			MIRA_THROW(XSystemCall, "RSA_public_encrypt failed: " <<
			           OpenSSLErrorString::instance().err2str(tErrNo));
		}

		// remove data from input buffer
		mInputBuffer.consume(tBlockLen);

		// append data to output buffer
		mOutputBuffer.sputn((const char*)mWrkBuffer, tSize);
	}

	///////////////////////////////////////////////////////////////////////////
	// output as much data as possible

	size_t tSpaceInBuffer = iDestEnd - ioDestBegin;
	size_t tCopySize = mOutputBuffer.size();
	if (tCopySize > tSpaceInBuffer)
		tCopySize = tSpaceInBuffer;

	// Only touch the buffer start when there is something to copy:
	// dereferencing the begin iterator of an empty buffer is undefined.
	if (tCopySize > 0) {
		const char *tDataEncrypted =
			&(*(boost::asio::buffers_begin(mOutputBuffer.data())));
		memcpy(ioDestBegin, tDataEncrypted, tCopySize);

		mOutputBuffer.consume(tCopySize);
		ioDestBegin += tCopySize;
	}

	bool tCallAgain = (mOutputBuffer.size() > 0) || (mInputBuffer.size() > 0);
	return(tCallAgain);
}

bool RSAFilterBase::decryptPrivate(const char*& ioSrcBegin,
                                   const char*  iSrcEnd,
                                   char*&       ioDestBegin,
                                   char*        iDestEnd,
                                   bool         iFlush)
{
	///////////////////////////////////////////////////////////////////////////
	// store new incoming encrypted data into mInputBuffer

	if (iSrcEnd != ioSrcBegin) {
		size_t tInputBlockSize = iSrcEnd - ioSrcBegin;
		mInputBuffer.sputn(ioSrcBegin, tInputBlockSize);
		ioSrcBegin = iSrcEnd;
	}

	///////////////////////////////////////////////////////////////////////////
	// if the input buffer has enough data or flush is true encrypt the data

	if (mInputBuffer.size() >= mRSASize) {
		const char *tDataPlain =
			&(*(boost::asio::buffers_begin(mInputBuffer.data())));
		int tSize = RSA_private_decrypt(mRSASize,
		                                (unsigned char*)tDataPlain,
		                                mWrkBuffer,
		                                mKey.getOpenSSLKey()->key, // private
		                                RSA_PKCS1_PADDING);
		if (tSize < 0) {
			unsigned long tErrNo = ERR_get_error();
			MIRA_THROW(XSystemCall, "RSA_private_decrypt failed: " << 
			           OpenSSLErrorString::instance().err2str(tErrNo));
		}

		// remove data from input buffer
		mInputBuffer.consume(mRSASize);

		// append data to output buffer
		mOutputBuffer.sputn((const char*)mWrkBuffer, tSize);
	}

	///////////////////////////////////////////////////////////////////////////
	// output as much data as possible

	size_t tSpaceInBuffer = iDestEnd - ioDestBegin;
	const char *tDataPlain =
		&(*(boost::asio::buffers_begin(mOutputBuffer.data())));

	size_t tCopySize = mOutputBuffer.size();
	if (tCopySize > tSpaceInBuffer)
		tCopySize = tSpaceInBuffer;

	memcpy(ioDestBegin, tDataPlain, tCopySize);

	mOutputBuffer.consume(tCopySize);
	ioDestBegin += tCopySize;

	bool tCallAgain = (mOutputBuffer.size() > 0) || (mInputBuffer.size() > 0);
	return(tCallAgain);
}

void RSAFilterBase::reset()
{
	// Discard everything still pending in either direction, so the filter
	// can start over with a fresh stream (key and work buffer are kept).
	const size_t tPendingInput = mInputBuffer.size();
	mInputBuffer.consume(tPendingInput);

	const size_t tPendingOutput = mOutputBuffer.size();
	mOutputBuffer.consume(tPendingOutput);
}

void RSAFilterBase::initFilter(const RSAKey& key, bool encrypt, void* alloc)
{
	mKey = key;

	// We use PKCS#1 v1.5 padding (RSA_PKCS1_PADDING). This means that values
	// being encrypted must be less than the size of the modulus (=mRSASize)
	// in bytes minus 10 bytes long.
	mRSASize   = RSA_size(mKey.getOpenSSLKey()->key);
	mBlockSize = mRSASize - 11; // for RSA_PKCS1_PADDING

	if (mWrkBuffer != NULL)
		delete [] mWrkBuffer;
	mWrkBuffer = new uint8[mRSASize];
}

///////////////////////////////////////////////////////////////////////////////

} // end of namespace Private

///////////////////////////////////////////////////////////////////////////////

} // namespace
