/*
 * Copyright (C) 2012 by
 *   MetraLabs GmbH (MLAB), GERMANY
 * and
 *   Neuroinformatics and Cognitive Robotics Labs (NICR) at TU Ilmenau, GERMANY
 * All rights reserved.
 *
 * Contact: info@mira-project.org
 *
 * Commercial Usage:
 *   Licensees holding valid commercial licenses may use this file in
 *   accordance with the commercial license agreement provided with the
 *   software or, alternatively, in accordance with the terms contained in
 *   a written agreement between you and MLAB or NICR.
 *
 * GNU General Public License Usage:
 *   Alternatively, this file may be used under the terms of the GNU
 *   General Public License version 3.0 as published by the Free Software
 *   Foundation and appearing in the file LICENSE.GPL3 included in the
 *   packaging of this file. Please review the following information to
 *   ensure the GNU General Public License version 3.0 requirements will be
 *   met: http://www.gnu.org/copyleft/gpl.html.
 *   Alternatively you may (at your option) use any later version of the GNU
 *   General Public License if such license has been publicly approved by
 *   MLAB and NICR (or its successors, if any).
 *
 * IN NO EVENT SHALL "MLAB" OR "NICR" BE LIABLE TO ANY PARTY FOR DIRECT,
 * INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
 * THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF "MLAB" OR
 * "NICR" HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * "MLAB" AND "NICR" SPECIFICALLY DISCLAIM ANY WARRANTIES, INCLUDING,
 * BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS
 * ON AN "AS IS" BASIS, AND "MLAB" AND "NICR" HAVE NO OBLIGATION TO
 * PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS OR MODIFICATIONS.
 */

/**
 * @file RSAFilter.h
 *    RSA encryption/decryption filters based on boost::iostreams::symmetric_filter.
 *
 * @author Christian Martin
 * @date   2010/09/20
 */

#ifndef _MIRA_RSAFILTER_H_
#define _MIRA_RSAFILTER_H_

#include <limits>
#include <memory>
#include <new>

#ifndef Q_MOC_RUN
#include <boost/iostreams/concepts.hpp>
#include <boost/iostreams/filter/symmetric.hpp>
#include <boost/asio/basic_streambuf.hpp>
#endif

#include <platform/Types.h>
#include <error/Exceptions.h>
#include <security/RSAKey.h>

namespace mira {

///////////////////////////////////////////////////////////////////////////////

///@cond INTERNAL

namespace Private {

///////////////////////////////////////////////////////////////////////////////
// An allocator trait for the boost iostream filter

/**
 * Maps an arbitrary allocator type \a Alloc to the corresponding allocator
 * for \c char.
 *
 * std::allocator_traits::rebind_alloc is used when C++11 allocator support
 * is available. The nested Alloc::rebind member template was deprecated in
 * C++17 and removed from std::allocator in C++20. The classic rebind idiom
 * is only kept for pre-C++11 standard libraries (BOOST_NO_CXX11_ALLOCATOR).
 * std::allocator<char> is used when Boost.Config reports a non-conforming
 * std::allocator (BOOST_NO_STD_ALLOCATOR).
 */
template<typename Alloc>
struct RSAFilterAllocatorTraits
{
#ifndef BOOST_NO_STD_ALLOCATOR
#  ifndef BOOST_NO_CXX11_ALLOCATOR
	typedef typename std::allocator_traits<Alloc>::template rebind_alloc<char> type;
#  else
	typedef typename Alloc::template rebind<char>::other type;
#  endif
#else
	typedef std::allocator<char> type;
#endif
};

///////////////////////////////////////////////////////////////////////////////
// An allocator for the boost iostreams filter

/**
 * Adapter exposing a C-style allocate/deallocate callback pair on top of a
 * char allocator (by default the one selected by RSAFilterAllocatorTraits).
 *
 * The allocator itself is held as a private base. Both static callbacks
 * receive an opaque \a self pointer, which they cast to allocator_type*.
 */
template<typename Alloc,
         typename Base = BOOST_DEDUCED_TYPENAME RSAFilterAllocatorTraits<Alloc>::type>
struct RSAFilterAllocator : private Base
{
public:
	/// The char allocator type used by allocate() and deallocate().
	typedef typename RSAFilterAllocatorTraits<Alloc>::type allocator_type;

	/// True if the allocator in use is not std::allocator<char>.
	BOOST_STATIC_CONSTANT(bool, custom =
		(!boost::is_same<std::allocator<char>, Base>::value));

	/// Allocates a block of items*size bytes; release it with deallocate().
	static void* allocate(void* self, uint32 items, uint32 size);

	/// Releases a block previously obtained from allocate().
	static void deallocate(void* self, void* address);

private:
	typedef typename Base::size_type size_type;
};

///////////////////////////////////////////////////////////////////////////////
// An internal class of the RSA filter class

/**
 * Non-template core shared by RSAPublicEncryptionImpl and
 * RSAPrivateDecryptionImpl.
 *
 * The actual encryption/decryption routines are implemented out of line;
 * the template wrappers only forward to them.
 */
class MIRA_BASE_EXPORT RSAFilterBase
{
public:
	typedef char char_type;

protected:
	/// The constructor.
	RSAFilterBase();

	/// The destructor.
	~RSAFilterBase();

	/**
	 * Initializes the filter.
	 * @param key     The RSA key to use.
	 * @param encrypt true selects public encryption (as passed by
	 *                RSAPublicEncryptionImpl), false selects private
	 *                decryption (as passed by RSAPrivateDecryptionImpl).
	 * @param alloc   Allocator adapter, forwarded to initFilter() as an
	 *                opaque pointer.
	 */
	template<typename Alloc>
	void init(const RSAKey& key, bool encrypt, RSAFilterAllocator<Alloc>& alloc)
	{
		initFilter(key, encrypt, &alloc);
	}

	/**
	 * Public encryption.
	 *
	 * The signature follows the boost::iostreams SymmetricFilter::filter()
	 * convention: input is taken from [ioSrcBegin, iSrcEnd), output goes to
	 * [ioDestBegin, iDestEnd). The begin pointers are passed by reference,
	 * presumably so they can be advanced past the consumed input / produced
	 * output.
	 */
	bool encryptPublic(const char*& ioSrcBegin,
	                   const char*  iSrcEnd,
	                   char*&       ioDestBegin,
	                   char*        iDestEnd,
	                   bool         iFlush);

	/**
	 * Private decryption.
	 *
	 * Same parameter convention as encryptPublic().
	 */
	bool decryptPrivate(const char*& ioSrcBegin,
	                    const char*  iSrcEnd,
	                    char*&       ioDestBegin,
	                    char*        iDestEnd,
	                    bool         iFlush);

	/// Reset all buffers.
	void reset();

private:
	/// Non-template part of init(); \a alloc is the RSAFilterAllocator adapter.
	void initFilter(const RSAKey& key, bool encrypt, void* alloc);

private:
	// NOTE: member order is part of the exported class layout; do not reorder.

	/// The RSA key (presumably the one handed to initFilter()).
	RSAKey mKey;

	/// Internal input/output buffers (see implementation for their exact use).
	boost::asio::basic_streambuf<> mInputBuffer;
	boost::asio::basic_streambuf<> mOutputBuffer;

	/// Work buffer (NOTE(review): ownership/allocation handled in the .cpp — confirm).
	uint8* mWrkBuffer;
	/// Presumably the RSA modulus size in bytes — TODO confirm.
	size_t mRSASize;
	/// Presumably the maximum plaintext block size per RSA operation — TODO confirm.
	size_t mBlockSize;

#ifdef MIRA_USE_OPENSSL3
	struct KeyCtx;
	KeyCtx* mCtx;
#endif
};

///////////////////////////////////////////////////////////////////////////////
// Template name: RSAPublicEncryptionImpl

/**
 * SymmetricFilter implementation performing RSA public-key encryption.
 *
 * Used as the SymmetricFilter parameter of
 * boost::iostreams::symmetric_filter by BasicRSAPublicEncryptionFilter.
 * All work is forwarded to RSAFilterBase.
 */
template<typename Alloc = std::allocator<char> >
class RSAPublicEncryptionImpl :
	public RSAFilterBase,
	public RSAFilterAllocator<Alloc>
{
public:
	/// Initializes the filter for encryption with \a key.
	/// Explicit: an RSAKey must not silently convert into a filter impl.
	explicit RSAPublicEncryptionImpl(const RSAKey& key);

	/// Resets all buffers.
	~RSAPublicEncryptionImpl();

	/// Encrypts data from [ioSrcBegin, iSrcEnd) into [ioDestBegin, iDestEnd)
	/// (SymmetricFilter::filter() convention).
	bool filter(const char* &ioSrcBegin, const char* iSrcEnd,
	            char* &ioDestBegin, char* iDestEnd, bool iFlush);

	/// Resets all buffers (SymmetricFilter::close()).
	void close();
};

///////////////////////////////////////////////////////////////////////////////
// Template name: RSAPrivateDecryptionImpl

/**
 * SymmetricFilter implementation performing RSA private-key decryption.
 *
 * Used as the SymmetricFilter parameter of
 * boost::iostreams::symmetric_filter by BasicRSAPrivateDecryptionFilter.
 * All work is forwarded to RSAFilterBase.
 */
template<typename Alloc = std::allocator<char> >
class RSAPrivateDecryptionImpl :
	public RSAFilterBase,
	public RSAFilterAllocator<Alloc>
{
public:
	/// Initializes the filter for decryption with \a key.
	/// Explicit: an RSAKey must not silently convert into a filter impl.
	explicit RSAPrivateDecryptionImpl(const RSAKey& key);

	/// Resets all buffers.
	~RSAPrivateDecryptionImpl();

	/// Decrypts data from [ioSrcBegin, iSrcEnd) into [ioDestBegin, iDestEnd)
	/// (SymmetricFilter::filter() convention).
	bool filter(const char* &ioSrcBegin, const char* iSrcEnd,
	            char* &ioDestBegin, char* iDestEnd, bool iFlush);

	/// Resets all buffers (SymmetricFilter::close()).
	void close();
};

///////////////////////////////////////////////////////////////////////////////

} // end of namespace Private

///////////////////////////////////////////////////////////////////////////////
// BasicRSAPublicEncryptionFilter

/**
 * boost::iostreams symmetric filter that RSA-encrypts its input with a
 * public key. Use the RSAPublicEncryptionFilter typedef in client code.
 *
 * @tparam Alloc Allocator forwarded to boost::iostreams::symmetric_filter.
 */
template<typename Alloc = std::allocator<char> >
struct BasicRSAPublicEncryptionFilter
	: boost::iostreams::symmetric_filter<
		Private::RSAPublicEncryptionImpl<Alloc>,
		Alloc>
{
private:
	typedef Private::RSAPublicEncryptionImpl<Alloc> impl_type;
	typedef boost::iostreams::symmetric_filter<impl_type, Alloc> base_type;

public:
	typedef typename base_type::category  category;
	typedef typename base_type::char_type char_type;

	/**
	 * Creates the filter.
	 * @param key        The public key used for encryption.
	 * @param bufferSize Buffer size passed on to the symmetric_filter base.
	 * @throw XInvalidConfig if \a key is not a public key.
	 */
	BasicRSAPublicEncryptionFilter(
		const RSAKey& key,
		int bufferSize = boost::iostreams::default_device_buffer_size);
};
BOOST_IOSTREAMS_PIPABLE(BasicRSAPublicEncryptionFilter, 1)

///////////////////////////////////////////////////////////////////////////////
// BasicRSAPrivateDecryptionFilter

/**
 * boost::iostreams symmetric filter that RSA-decrypts its input with a
 * private key. Use the RSAPrivateDecryptionFilter typedef in client code.
 *
 * @tparam Alloc Allocator forwarded to boost::iostreams::symmetric_filter.
 */
template<typename Alloc = std::allocator<char> >
struct BasicRSAPrivateDecryptionFilter
	: boost::iostreams::symmetric_filter<
		Private::RSAPrivateDecryptionImpl<Alloc>,
		Alloc>
{
private:
	typedef Private::RSAPrivateDecryptionImpl<Alloc> impl_type;
	typedef boost::iostreams::symmetric_filter<impl_type, Alloc> base_type;

public:
	typedef typename base_type::category  category;
	typedef typename base_type::char_type char_type;

	/**
	 * Creates the filter.
	 * @param key        The private key used for decryption.
	 * @param bufferSize Buffer size passed on to the symmetric_filter base.
	 * @throw XInvalidConfig if \a key is not a private key.
	 */
	BasicRSAPrivateDecryptionFilter(
		const RSAKey& key,
		int bufferSize = boost::iostreams::default_device_buffer_size);
};
BOOST_IOSTREAMS_PIPABLE(BasicRSAPrivateDecryptionFilter, 1)

///////////////////////////////////////////////////////////////////////////////

///@endcond

///////////////////////////////////////////////////////////////////////////////

/**
 * @brief An RSA public encryption filter for boost::iostreams.
 *
 * Usage examples:
 * \code
 *     RSAKey publicKey, privateKey;
 *     RSAKey::generateKey(2048, publicKey, privateKey);
 *
 *     boost::iostreams::filtering_ostream outStream;
 *     outStream.push(RSAPublicEncryptionFilter(publicKey));
 *     outStream.push(boost::iostreams::file_sink("rsa.dat", std::ios::out | std::ios::binary));
 *     outStream << "Hello_world!" << std::endl;
 * \endcode
 *
 * \code
 *     RSAKey publicKey, privateKey;
 *     RSAKey::generateKey(2048, publicKey, privateKey);
 *
 *     std::string inString = "Hello-world!";
 *     std::vector<char> buffer;
 *
 *     boost::iostreams::back_insert_device<std::vector<char>> sink{buffer};
 *     boost::iostreams::filtering_ostream outStream;
 *     outStream.push(RSAPublicEncryptionFilter(publicKey));
 *     outStream.push(sink);
 *     outStream << inString;
 *     outStream.pop();
 * \endcode
 *
 * The constructor throws XInvalidConfig if the given key is not a public key.
 *
 * @ingroup SecurityModule
 */
typedef BasicRSAPublicEncryptionFilter<std::allocator<char> > RSAPublicEncryptionFilter;

/**
 * @brief An RSA private decryption filter for boost::iostreams.
 *
 * The private key must belong to the same key pair whose public key was
 * used for encryption. The examples generate keys only to show the types.
 *
 * Usage examples:
 * \code
 *     RSAKey publicKey, privateKey;
 *     RSAKey::generateKey(2048, publicKey, privateKey);
 *
 *     std::string inString;
 *     boost::iostreams::filtering_istream inStream;
 *     inStream.push(RSAPrivateDecryptionFilter(privateKey));
 *     inStream.push(boost::iostreams::file_source("rsa.dat", std::ios::in | std::ios::binary));
 *     inStream >> inString;
 * \endcode
 *
 * \code
 *     RSAKey publicKey, privateKey;
 *     RSAKey::generateKey(2048, publicKey, privateKey);
 *
 *     // 'buffer' holds encrypted data, e.g. the std::vector<char> filled
 *     // by the second RSAPublicEncryptionFilter example.
 *     std::string outStr;
 *     boost::iostreams::array_source source{(char*)buffer.data(), buffer.size()};
 *     boost::iostreams::filtering_istream inStream;
 *     inStream.push(RSAPrivateDecryptionFilter(privateKey));
 *     inStream.push(source);
 *     inStream >> outStr;
 * \endcode
 *
 * The constructor throws XInvalidConfig if the given key is not a private key.
 *
 * @ingroup SecurityModule
 */
typedef BasicRSAPrivateDecryptionFilter<std::allocator<char> > RSAPrivateDecryptionFilter;

///////////////////////////////////////////////////////////////////////////////
// Template implementation
///////////////////////////////////////////////////////////////////////////////

///@cond INTERNAL

namespace Private {

///////////////////////////////////////////////////////////////////////////////
// Implementation of template RSAFilterAllocator

/**
 * C-style allocation callback.
 *
 * Allocates \a items * \a size bytes through the allocator pointed to by
 * \a self. The requested length is stored in a size_type header directly in
 * front of the returned pointer, so that deallocate() can recover it.
 *
 * @throw std::bad_alloc if the requested size plus the header is not
 *        representable in size_type (or if the allocator itself fails).
 */
template<typename Alloc, typename Base>
void* RSAFilterAllocator<Alloc, Base>::allocate(void* self, uint32 items,
                                                uint32 size)
{
	// The product must be formed in size_type, not uint32: a uint32
	// multiplication wraps silently and would yield a too-small block.
	// Also leave room for the length header. The parenthesized max avoids
	// clashes with a max() macro (e.g. from <windows.h>).
	const size_type maxLen =
		(std::numeric_limits<size_type>::max)() - sizeof(size_type);
	if (size != 0 && static_cast<size_type>(items) > maxLen / size)
		throw std::bad_alloc();

	size_type len = static_cast<size_type>(items) * static_cast<size_type>(size);
	char* ptr =
		static_cast<allocator_type*>(self)->allocate
			(len + sizeof(size_type)
			#if BOOST_WORKAROUND(BOOST_DINKUMWARE_STDLIB, == 1)
				, (char*)0
			#endif
			);
	*reinterpret_cast<size_type*>(ptr) = len;
	return ptr + sizeof(size_type);
}

/**
 * C-style deallocation callback, counterpart to allocate().
 *
 * Steps back over the length header written by allocate() to find the real
 * block start and total size. It then returns the block to the allocator
 * pointed to by \a self.
 */
template<typename Alloc, typename Base>
void RSAFilterAllocator<Alloc, Base>::deallocate(void* self, void* address)
{
	char* const block = static_cast<char*>(address) - sizeof(size_type);
	const size_type payload = *reinterpret_cast<const size_type*>(block);
	allocator_type* const allocator = static_cast<allocator_type*>(self);
	allocator->deallocate(block, payload + sizeof(size_type));
}

///////////////////////////////////////////////////////////////////////////////
// Implementation of RSAPublicEncryptionImpl

/// Sets up the shared filter core for public-key encryption.
template<typename Alloc>
RSAPublicEncryptionImpl<Alloc>::RSAPublicEncryptionImpl(const RSAKey& key)
{
	// Our own allocator base is handed to the core; 'true' selects encryption.
	RSAFilterAllocator<Alloc>& allocator = *this;
	this->init(key, true, allocator);
}

/// Clears all buffers before the object goes away.
template<typename Alloc>
RSAPublicEncryptionImpl<Alloc>::~RSAPublicEncryptionImpl()
{
	this->reset();
}

/// SymmetricFilter entry point; delegates to RSAFilterBase::encryptPublic().
template<typename Alloc>
bool RSAPublicEncryptionImpl<Alloc>::filter(const char* &ioSrcBegin,
                                            const char* iSrcEnd,
                                            char* &ioDestBegin,
                                            char* iDestEnd, bool iFlush)
{
	return this->encryptPublic(ioSrcBegin, iSrcEnd,
	                           ioDestBegin, iDestEnd,
	                           iFlush);
}

/// SymmetricFilter close hook; clears all buffers.
template<typename Alloc>
void RSAPublicEncryptionImpl<Alloc>::close()
{
	this->reset();
}

///////////////////////////////////////////////////////////////////////////////
// Implementation of RSAPrivateDecryptionImpl

/// Sets up the shared filter core for private-key decryption.
template<typename Alloc>
RSAPrivateDecryptionImpl<Alloc>::RSAPrivateDecryptionImpl(const RSAKey& key)
{
	// Our own allocator base is handed to the core; 'false' selects decryption.
	RSAFilterAllocator<Alloc>& allocator = *this;
	this->init(key, false, allocator);
}

/// Clears all buffers before the object goes away.
template<typename Alloc>
RSAPrivateDecryptionImpl<Alloc>::~RSAPrivateDecryptionImpl()
{
	this->reset();
}

/// SymmetricFilter entry point; delegates to RSAFilterBase::decryptPrivate().
template<typename Alloc>
bool RSAPrivateDecryptionImpl<Alloc>::filter(const char* &ioSrcBegin,
                                             const char* iSrcEnd,
                                             char* &ioDestBegin,
                                             char* iDestEnd, bool iFlush)
{
	return this->decryptPrivate(ioSrcBegin, iSrcEnd,
	                            ioDestBegin, iDestEnd,
	                            iFlush);
}

/// SymmetricFilter close hook; clears all buffers.
template<typename Alloc>
void RSAPrivateDecryptionImpl<Alloc>::close()
{
	this->reset();
}

///////////////////////////////////////////////////////////////////////////////

} // end of namespace Private

///////////////////////////////////////////////////////////////////////////////
// Implementation of BasicRSAPublicEncryptionFilter

template<typename Alloc>
BasicRSAPublicEncryptionFilter<Alloc>::BasicRSAPublicEncryptionFilter(
		const RSAKey& key, int bufferSize) :
	base_type(bufferSize, key)
{
	if (!key.isPublicKey())
		MIRA_THROW(XInvalidConfig, "The key is not a public key.");
}

///////////////////////////////////////////////////////////////////////////////
// Implementation of BasicRSAPrivateDecryptionFilter

template<typename Alloc>
BasicRSAPrivateDecryptionFilter<Alloc>::BasicRSAPrivateDecryptionFilter(
		const RSAKey& key, int bufferSize) :
	base_type(bufferSize, key)
{
	if (!key.isPrivateKey())
		MIRA_THROW(XInvalidConfig, "The key is not a private key.");
}

///////////////////////////////////////////////////////////////////////////////

///@endcond

///////////////////////////////////////////////////////////////////////////////

} // namespaces

#endif
