/*
 * Copyright (C) 2012 by
 *   MetraLabs GmbH (MLAB), GERMANY
 * and
 *   Neuroinformatics and Cognitive Robotics Labs (NICR) at TU Ilmenau, GERMANY
 * All rights reserved.
 *
 * Contact: info@mira-project.org
 *
 * Commercial Usage:
 *   Licensees holding valid commercial licenses may use this file in
 *   accordance with the commercial license agreement provided with the
 *   software or, alternatively, in accordance with the terms contained in
 *   a written agreement between you and MLAB or NICR.
 *
 * GNU General Public License Usage:
 *   Alternatively, this file may be used under the terms of the GNU
 *   General Public License version 3.0 as published by the Free Software
 *   Foundation and appearing in the file LICENSE.GPL3 included in the
 *   packaging of this file. Please review the following information to
 *   ensure the GNU General Public License version 3.0 requirements will be
 *   met: http://www.gnu.org/copyleft/gpl.html.
 *   Alternatively you may (at your option) use any later version of the GNU
 *   General Public License if such license has been publicly approved by
 *   MLAB and NICR (or its successors, if any).
 *
 * IN NO EVENT SHALL "MLAB" OR "NICR" BE LIABLE TO ANY PARTY FOR DIRECT,
 * INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
 * THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF "MLAB" OR
 * "NICR" HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * "MLAB" AND "NICR" SPECIFICALLY DISCLAIM ANY WARRANTIES, INCLUDING,
 * BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS
 * ON AN "AS IS" BASIS, AND "MLAB" AND "NICR" HAVE NO OBLIGATION TO
 * PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS OR MODIFICATIONS.
 */

/**
 * @file AESFilter.h
 *    Classes for AES encryption/decryption based on boost::iostreams filters.
 *
 * @author Christian Martin
 * @date   2012/06/17
 */

#ifndef _MIRA_AESFILTER_H_
#define _MIRA_AESFILTER_H_

#include <memory>
#include <new>

#ifndef Q_MOC_RUN
#include <boost/iostreams/concepts.hpp>
#include <boost/iostreams/filter/symmetric.hpp>
#include <boost/asio/basic_streambuf.hpp>
#endif

#include <platform/Types.h>
#include <error/Exceptions.h>

namespace mira {

///////////////////////////////////////////////////////////////////////////////

/**
 * Key lengths supported by the AES filters.
 * The numbering is explicit and starts at 1.
 */
enum AESBitLength {
	AES_128 = 1, ///< 128 bit key
	AES_192 = 2, ///< 192 bit key
	AES_256 = 3  ///< 256 bit key
};

/**
 * Block cipher modes of operation supported for AES encryption/decryption.
 * The numbering is explicit and starts at 1.
 */
enum AESBlockCipherMode
{
	AES_ECB    = 1, ///< Electronic codebook
	AES_CBC    = 2, ///< Cipher-block chaining
	AES_CFB1   = 3, ///< Cipher feedback (1 feedback bit)
	AES_CFB8   = 4, ///< Cipher feedback (8 feedback bits)
	AES_CFB128 = 5, ///< Cipher feedback (128 feedback bits)
	AES_OFB    = 6  ///< Output feedback
};

/**
 * Message digest algorithms that can be used to derive the AES key.
 * The numbering is explicit and starts at 0.
 */
enum AESMessageDigestAlgo
{
	AES_MD_MD5    = 0, ///< Use MD5 as message digest.
	AES_MD_SHA1   = 1, ///< Use SHA1 as message digest.
	AES_MD_SHA256 = 2, ///< Use SHA256 as message digest.
	AES_MD_SHA512 = 3  ///< Use SHA512 as message digest.
};

/**
 * Settings for AES encryption and decryption.
 *
 * A default-constructed configuration uses AES-256 in CBC mode, derives
 * the key with SHA1 over 5 hashing rounds and has an empty salt.
 */
struct AESConfiguration
{
	/// Length of the AES key (default: AES_256).
	AESBitLength bitLength;

	/// Block cipher mode of operation (default: AES_CBC).
	AESBlockCipherMode blockCipherMode;

	/// Message digest used for key generation (default: AES_MD_SHA1).
	AESMessageDigestAlgo mdAlgo;

	/// How many times the key material is hashed (default: 5).
	/// A higher count is more secure but slower.
	uint16 nrRounds;

	/// Salt for key generation. Must be either empty or exactly 8 long.
	std::string salt;

	/// Initializes the defaults listed above; salt stays empty.
	AESConfiguration()
		: bitLength(AES_256)
		, blockCipherMode(AES_CBC)
		, mdAlgo(AES_MD_SHA1)
		, nrRounds(5)
	{
	}

	/// Exposes every member, in declaration order, to the given reflector.
	template<typename Reflector>
	void reflect(Reflector& r)
	{
		r.member("BitLength", bitLength, "The AES key bit length.");
		r.member("BlockCipherMode", blockCipherMode,
		         "The used block cipher mode.");
		r.member("MessageDigestAlgorithm", mdAlgo,
		         "The message digest algorithm for key generation.");
		r.member("NrRounds", nrRounds,
		         "Number of times the key material is hashed.");
		r.member("Salt", salt, "The salt data.");
	}
};

///////////////////////////////////////////////////////////////////////////////

///@cond INTERNAL

namespace Private {

///////////////////////////////////////////////////////////////////////////////
// An allocator trait for the boost iostream filter

/**
 * Maps an arbitrary allocator type to the corresponding char allocator.
 *
 * The nested Alloc::rebind<> member was deprecated in C++17 and removed
 * from std::allocator in C++20, so std::allocator_traits is used whenever
 * the compiler provides C++11. Older compilers keep the classic rebind form.
 */
template<typename Alloc>
struct AESFilterAllocatorTraits
{
#ifndef BOOST_NO_STD_ALLOCATOR
#  if __cplusplus >= 201103L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201103L)
	typedef typename std::allocator_traits<Alloc>::template rebind_alloc<char> type;
#  else
	typedef typename Alloc::template rebind<char>::other type;
#  endif
#else
	typedef std::allocator<char> type;
#endif
};

///////////////////////////////////////////////////////////////////////////////
// An allocator for the boost iostreams filter

/**
 * Adapter that exposes a char allocator through two static callbacks.
 * The callbacks take the adapter itself as an untyped "self" pointer.
 *
 * Each block handed out by allocate() is preceded by a hidden size_type
 * header storing the payload length, which deallocate() reads back.
 *
 * @tparam Alloc User supplied allocator (rebound to char).
 * @tparam Base  The rebound char allocator this adapter derives from.
 */
template< typename Alloc,
          typename Base = BOOST_DEDUCED_TYPENAME AESFilterAllocatorTraits<Alloc>::type >
struct AESFilterAllocator :
	private Base
{
public:
	/// The char allocator type used for the actual memory management.
	typedef typename AESFilterAllocatorTraits<Alloc>::type allocator_type;

	/// true if a non-default (i.e. not std::allocator<char>) allocator is used.
	BOOST_STATIC_CONSTANT(bool, custom =
	                     (!boost::is_same<std::allocator<char>, Base>::value));

	/// Allocates items*size bytes using the allocator pointed to by self.
	static void* allocate(void* self, uint32 items, uint32 size);

	/// Releases a block previously returned by allocate().
	static void deallocate(void* self, void* address);

private:
	typedef typename Base::size_type size_type;
};

///////////////////////////////////////////////////////////////////////////////
// An internal class of the AES filter class

/**
 * Non-template core shared by AESEncryptionImpl and AESDecryptionImpl.
 *
 * It keeps the configuration, the cipher state (behind the opaque Context)
 * and the input/output staging buffers. initFilter(), encrypt(), decrypt()
 * and reset() are not defined in this header.
 *
 * The class is deliberately non-copyable. A member-wise copy would alias
 * the raw pointers mCtx and mWrkBuffer, and the stream buffers are not
 * copyable anyway.
 */
class AESFilterBase
{
public:
	/// Character type required by the boost::iostreams filter concepts.
	typedef char char_type;

protected:
	/// The constructor.
	AESFilterBase();

	/// The destructor. Protected and non-virtual, so objects can only be
	/// destroyed as part of a derived class.
	~AESFilterBase();

	/**
	 * Initializes the filter.
	 * @param cfg         The AES configuration.
	 * @param key         The key material (e.g. a password).
	 * @param encryptMode true for encryption, false for decryption.
	 * @param alloc       Allocator adapter, forwarded to initFilter() as void*.
	 */
	template<typename Alloc>
	void init(const AESConfiguration& cfg, const std::string& key,
	          bool encryptMode, AESFilterAllocator<Alloc>& alloc)
	{
		initFilter(cfg, key, encryptMode, &alloc);
	}

	/**
	 * AES encryption.
	 * Follows the SymmetricFilter::filter() signature: reads from
	 * [ioSrcBegin, iSrcEnd), writes to [ioDestBegin, iDestEnd) and
	 * advances both begin pointers. iFlush signals the end of the input.
	 */
	bool encrypt(const char*& ioSrcBegin,
	             const char*  iSrcEnd,
	             char*&       ioDestBegin,
	             char*        iDestEnd,
	             bool         iFlush);

	/**
	 * AES decryption.
	 * Same pointer/flush conventions as encrypt().
	 */
	bool decrypt(const char*& ioSrcBegin,
	             const char*  iSrcEnd,
	             char*&       ioDestBegin,
	             char*        iDestEnd,
	             bool         iFlush);

	/// reset all buffers
	void reset();

private:
	// Non-copyable: declared but intentionally never defined.
	AESFilterBase(const AESFilterBase&);
	AESFilterBase& operator=(const AESFilterBase&);

	void initFilter(const AESConfiguration& cfg, const std::string& key,
	                bool encryptMode, void* alloc);

private:
	/// Opaque cipher context, defined outside this header.
	struct Context;
	Context* mCtx;

private:
	AESConfiguration mCfg;

	bool mEncryptMode;
	bool mFinalCalled;

	boost::asio::basic_streambuf<> mInputBuffer;
	boost::asio::basic_streambuf<> mOutputBuffer;

	uint8* mWrkBuffer;
	size_t mWrkBufferSize;
};

///////////////////////////////////////////////////////////////////////////////
// Template name: AESEncryptionImpl

/**
 * Encrypting SymmetricFilter implementation.
 * BasicAESEncryptionFilter wraps it in a boost::iostreams::symmetric_filter.
 *
 * @tparam Alloc Allocator forwarded to AESFilterAllocator.
 */
template<typename Alloc = std::allocator<char> >
class AESEncryptionImpl :
	public AESFilterBase,
	public AESFilterAllocator<Alloc>
{
public:
	/// Initializes the AES core in encryption mode.
	AESEncryptionImpl(const AESConfiguration& cfg, const std::string& key);

	/// Calls reset() before the base is destroyed.
	~AESEncryptionImpl();

	/// SymmetricFilter::filter(); forwards to AESFilterBase::encrypt().
	bool filter(const char*& ioSrcBegin, const char* iSrcEnd,
	            char*& ioDestBegin, char* iDestEnd, bool iFlush);

	/// SymmetricFilter::close(); calls reset().
	void close();
};

///////////////////////////////////////////////////////////////////////////////
// Template name: AESDecryptionImpl

/**
 * Decrypting SymmetricFilter implementation.
 * BasicAESDecryptionFilter wraps it in a boost::iostreams::symmetric_filter.
 *
 * @tparam Alloc Allocator forwarded to AESFilterAllocator.
 */
template<typename Alloc = std::allocator<char> >
class AESDecryptionImpl :
	public AESFilterBase,
	public AESFilterAllocator<Alloc>
{
public:
	/// Initializes the AES core in decryption mode.
	AESDecryptionImpl(const AESConfiguration& cfg, const std::string& key);

	/// Calls reset() before the base is destroyed.
	~AESDecryptionImpl();

	/// SymmetricFilter::filter(); forwards to AESFilterBase::decrypt().
	bool filter(const char*& ioSrcBegin, const char* iSrcEnd,
	            char*& ioDestBegin, char* iDestEnd, bool iFlush);

	/// SymmetricFilter::close(); calls reset().
	void close();
};

///////////////////////////////////////////////////////////////////////////////

} // end of namespace Private

///////////////////////////////////////////////////////////////////////////////
// BasicAESEncryptionFilter

/**
 * boost::iostreams filter that AES-encrypts the data passing through it.
 * Use the AESEncryptionFilter typedef for the default allocator.
 *
 * @tparam Alloc Allocator used by symmetric_filter and the AES core.
 */
template<typename Alloc = std::allocator<char> >
struct BasicAESEncryptionFilter :
	public boost::iostreams::symmetric_filter<Private::AESEncryptionImpl<Alloc>, Alloc>
{
private:
	typedef Private::AESEncryptionImpl<Alloc>                     impl_type;
	typedef boost::iostreams::symmetric_filter<impl_type, Alloc>  base_type;

public:
	typedef typename base_type::char_type char_type;
	typedef typename base_type::category  category;

	/**
	 * @param cfg        The AES configuration.
	 * @param key        The key material (e.g. a password).
	 * @param bufferSize Size of the internal symmetric_filter buffer.
	 */
	BasicAESEncryptionFilter(const AESConfiguration& cfg,
	                         const std::string& key,
	                         int bufferSize = boost::iostreams::default_device_buffer_size);
};
BOOST_IOSTREAMS_PIPABLE(BasicAESEncryptionFilter, 1)

///////////////////////////////////////////////////////////////////////////////
// BasicAESDecryptionFilter

/**
 * boost::iostreams filter that AES-decrypts the data passing through it.
 * Use the AESDecryptionFilter typedef for the default allocator.
 *
 * @tparam Alloc Allocator used by symmetric_filter and the AES core.
 */
template<typename Alloc = std::allocator<char> >
struct BasicAESDecryptionFilter :
	public boost::iostreams::symmetric_filter<Private::AESDecryptionImpl<Alloc>, Alloc>
{
private:
	typedef Private::AESDecryptionImpl<Alloc>                     impl_type;
	typedef boost::iostreams::symmetric_filter<impl_type, Alloc>  base_type;

public:
	typedef typename base_type::char_type char_type;
	typedef typename base_type::category  category;

	/**
	 * @param cfg        The AES configuration (same as used for encryption).
	 * @param key        The key material (same as used for encryption).
	 * @param bufferSize Size of the internal symmetric_filter buffer.
	 */
	BasicAESDecryptionFilter(const AESConfiguration& cfg,
	                         const std::string& key,
	                         int bufferSize = boost::iostreams::default_device_buffer_size);
};
BOOST_IOSTREAMS_PIPABLE(BasicAESDecryptionFilter, 1)

///////////////////////////////////////////////////////////////////////////////

///@endcond

///////////////////////////////////////////////////////////////////////////////

/**
 * @brief An AES encryption filter for boost::iostreams.
 *
 * AES is a symmetric cipher. Data written through this filter is read back
 * with AESDecryptionFilter, using the same AESConfiguration and key.
 *
 * Usage example:
 * \code
 *
 *     AESConfiguration cfg;
 *     cfg.bitLength = AES_256;
 *     cfg.blockCipherMode = AES_CBC;
 *     cfg.salt = "12345678";
 *
 *     boost::iostreams::filtering_ostream tOut;
 *     tOut.push(AESEncryptionFilter(cfg, "Password"));
 *     tOut.push(boost::iostreams::file_sink("AES.dat",
 *                                           std::ios::out | std::ios::binary));
 *     tOut << "Hello_world!" << std::endl;
 *     // Close the stream (or let it go out of scope) before reading the
 *     // file back, so that the filter is flushed completely.
 *
 * \endcode
 *
 * @ingroup SecurityModule
 */
typedef BasicAESEncryptionFilter<> AESEncryptionFilter;

/**
 * @brief An AES decryption filter for boost::iostreams.
 *
 * Reverses AESEncryptionFilter. The AESConfiguration and key must be the
 * same as the ones used for encryption.
 *
 * Usage example:
 * \code
 *
 *     AESConfiguration cfg;
 *     cfg.bitLength = AES_256;
 *     cfg.blockCipherMode = AES_CBC;
 *     cfg.salt = "12345678";
 *
 *     std::string tMsg;
 *
 *     boost::iostreams::filtering_istream tIn;
 *     tIn.push(AESDecryptionFilter(cfg, "Password"));
 *     tIn.push(boost::iostreams::file_source("AES.dat",
 *                                            std::ios::in | std::ios::binary));
 *     tIn >> tMsg; // reads the first whitespace-delimited token
 *
 * \endcode
 *
 * @ingroup SecurityModule
 */
typedef BasicAESDecryptionFilter<> AESDecryptionFilter;

///////////////////////////////////////////////////////////////////////////////
// Template implementation
///////////////////////////////////////////////////////////////////////////////

///@cond INTERNAL

namespace Private {

///////////////////////////////////////////////////////////////////////////////
// Implementation of template AESFilterAllocator

/**
 * Allocates items*size bytes from the allocator pointed to by self.
 *
 * The returned pointer is preceded by a hidden size_type header holding the
 * payload length, which deallocate() uses to release the whole block.
 *
 * @param self  Pointer to the allocator_type object to allocate from.
 * @param items Number of elements.
 * @param size  Size of each element in bytes.
 * @return Pointer to the payload (just past the length header).
 * @throw std::bad_alloc if the requested size plus header cannot be
 *        represented in size_type, or if the underlying allocator fails.
 */
template<typename Alloc, typename Base>
void* AESFilterAllocator<Alloc, Base>::allocate(void* self, uint32 items,
                                                uint32 size)
{
	// Reject requests whose total size (payload + header) would wrap around;
	// a wrapped size would yield an undersized block and a heap overflow.
	const size_type maxSize = static_cast<size_type>(-1);
	if (size != 0 &&
	    static_cast<size_type>(items) >
	        (maxSize - sizeof(size_type)) / static_cast<size_type>(size))
		throw std::bad_alloc();

	// Widen before multiplying: items * size in 32-bit arithmetic could wrap.
	size_type len = static_cast<size_type>(items) * static_cast<size_type>(size);
	char* ptr =
		static_cast<allocator_type*>(self)->allocate
			(len + sizeof(size_type)
			#if BOOST_WORKAROUND(BOOST_DINKUMWARE_STDLIB, == 1)
				, static_cast<char*>(0)
			#endif
			);
	*reinterpret_cast<size_type*>(ptr) = len;
	return ptr + sizeof(size_type);
}

/**
 * Releases a block obtained from allocate().
 *
 * @param self    Pointer to the allocator_type object the block came from.
 * @param address Payload pointer returned by allocate().
 */
template<typename Alloc, typename Base>
void AESFilterAllocator<Alloc, Base>::deallocate(void* self, void* address)
{
	// Step back over the length header to the real start of the block.
	char* const block = static_cast<char*>(address) - sizeof(size_type);
	const size_type payload = *reinterpret_cast<const size_type*>(block);

	allocator_type* const alloc = static_cast<allocator_type*>(self);
	alloc->deallocate(block, payload + sizeof(size_type));
}

///////////////////////////////////////////////////////////////////////////////
// Implementation of AESEncryptionImpl

template<typename Alloc>
AESEncryptionImpl<Alloc>::AESEncryptionImpl(const AESConfiguration& cfg,
                                            const std::string& key)
{
	// Hand our AESFilterAllocator base to the AES core; 'true' selects
	// encryption mode.
	AESFilterAllocator<Alloc>& alloc = *this;
	init(cfg, key, true, alloc);
}

template<typename Alloc>
AESEncryptionImpl<Alloc>::~AESEncryptionImpl()
{
	// Clear all buffers of the AES core.
	this->reset();
}

template<typename Alloc>
bool AESEncryptionImpl<Alloc>::filter(const char*& ioSrcBegin,
                                      const char* iSrcEnd,
                                      char*& ioDestBegin,
                                      char* iDestEnd, bool iFlush)
{
	// The actual cipher work is done by the shared AES core.
	const bool result =
		encrypt(ioSrcBegin, iSrcEnd, ioDestBegin, iDestEnd, iFlush);
	return result;
}

template<typename Alloc>
void AESEncryptionImpl<Alloc>::close()
{
	// SymmetricFilter close hook: clear all buffers of the AES core.
	this->reset();
}

///////////////////////////////////////////////////////////////////////////////
// Implementation of AESDecryptionImpl

template<typename Alloc>
AESDecryptionImpl<Alloc>::AESDecryptionImpl(const AESConfiguration& cfg,
                                            const std::string& key)
{
	// Hand our AESFilterAllocator base to the AES core; 'false' selects
	// decryption mode.
	AESFilterAllocator<Alloc>& alloc = *this;
	init(cfg, key, false, alloc);
}

template<typename Alloc>
AESDecryptionImpl<Alloc>::~AESDecryptionImpl()
{
	// Clear all buffers of the AES core.
	this->reset();
}

template<typename Alloc>
bool AESDecryptionImpl<Alloc>::filter(const char*& ioSrcBegin,
                                      const char* iSrcEnd,
                                      char*& ioDestBegin,
                                      char* iDestEnd, bool iFlush)
{
	// The actual cipher work is done by the shared AES core.
	const bool result =
		decrypt(ioSrcBegin, iSrcEnd, ioDestBegin, iDestEnd, iFlush);
	return result;
}

template<typename Alloc>
void AESDecryptionImpl<Alloc>::close()
{
	// SymmetricFilter close hook: clear all buffers of the AES core.
	this->reset();
}

///////////////////////////////////////////////////////////////////////////////

} // end of namespace Private

///////////////////////////////////////////////////////////////////////////////
// Implementation of BasicAESEncryptionFilter

template<typename Alloc>
BasicAESEncryptionFilter<Alloc>::BasicAESEncryptionFilter(
		const AESConfiguration& cfg, const std::string& key, int bufferSize)
	: base_type(bufferSize, cfg, key)
{
	// symmetric_filter forwards cfg and key to the AESEncryptionImpl
	// constructor; nothing else to do here.
}

///////////////////////////////////////////////////////////////////////////////
// Implementation of BasicAESDecryptionFilter

template<typename Alloc>
BasicAESDecryptionFilter<Alloc>::BasicAESDecryptionFilter(
		const AESConfiguration& cfg, const std::string& key, int bufferSize)
	: base_type(bufferSize, cfg, key)
{
	// symmetric_filter forwards cfg and key to the AESDecryptionImpl
	// constructor; nothing else to do here.
}

///////////////////////////////////////////////////////////////////////////////

///@endcond

///////////////////////////////////////////////////////////////////////////////

} // namespaces

#endif
