#ifndef MS_UTILS_HPP
#define MS_UTILS_HPP

#include "common.hpp"
#include "RTC/Consts.hpp"
#include <openssl/evp.h>
#include <algorithm> // std::transform
#include <cctype>    // std::tolower
#include <cmath>
#include <cstring> // std::memcmp(), std::memcpy()
#include <limits>  // std::numeric_limits
#include <string>
#include <type_traits> // std::enable_if, std::is_same_v
#ifdef _WIN32
#include <ws2ipdef.h>
// https://stackoverflow.com/a/24550632/2085408
#include <intrin.h>
#define __builtin_popcount __popcnt
#endif

namespace Utils
{
	class IP
	{
	public:
		static int GetFamily(const std::string& ip);

		static void GetAddressInfo(const struct sockaddr* addr, int& family, std::string& ip, uint16_t& port);

		static size_t GetAddressLen(const struct sockaddr* addr);

		/**
		 * Compare two socket addresses by family, port and IP.
		 *
		 * @return true only when both addresses have the same family (AF_INET or
		 *   AF_INET6) and equal port and IP. Any other family yields false.
		 */
		static bool CompareAddresses(const struct sockaddr* addr1, const struct sockaddr* addr2)
		{
			// Families must match; once they do, switching on addr1 covers both.
			if (addr1->sa_family != addr2->sa_family)
			{
				return false;
			}

			switch (addr1->sa_family)
			{
				case AF_INET:
				{
					const auto* in1 = reinterpret_cast<const struct sockaddr_in*>(addr1);
					const auto* in2 = reinterpret_cast<const struct sockaddr_in*>(addr2);

					return in1->sin_port == in2->sin_port && in1->sin_addr.s_addr == in2->sin_addr.s_addr;
				}

				case AF_INET6:
				{
					const auto* in1 = reinterpret_cast<const struct sockaddr_in6*>(addr1);
					const auto* in2 = reinterpret_cast<const struct sockaddr_in6*>(addr2);

					// Read the port through sockaddr_in6 itself instead of relying on
					// sin_port and sin6_port happening to share the same offset.
					return in1->sin6_port == in2->sin6_port &&
					       std::memcmp(
					         std::addressof(in1->sin6_addr),
					         std::addressof(in2->sin6_addr),
					         sizeof(in1->sin6_addr)) == 0;
				}

				default:
				{
					return false;
				}
			}
		}

		/**
		 * Copy an AF_INET or AF_INET6 address into a zero-initialized
		 * sockaddr_storage. For any other family the returned storage is all zeros.
		 */
		static struct sockaddr_storage CopyAddress(const struct sockaddr* addr)
		{
			struct sockaddr_storage copiedAddr
			{
			};

			switch (addr->sa_family)
			{
				case AF_INET:
					std::memcpy(std::addressof(copiedAddr), addr, sizeof(struct sockaddr_in));
					break;

				case AF_INET6:
					std::memcpy(std::addressof(copiedAddr), addr, sizeof(struct sockaddr_in6));
					break;

				default:
					// Unsupported family: leave the storage zeroed.
					break;
			}

			return copiedAddr;
		}

		static void NormalizeIp(std::string& ip);
	};

	class File
	{
	public:
		/**
		 * Check the given file path. Implemented out of line: what exactly is
		 * checked and how a failure is reported are not visible here
		 * (presumably existence/readability, failing loudly — TODO confirm in
		 * the .cpp).
		 */
		static void CheckFile(const char* file);
	};

	class Byte
	{
	public:
		/**
		 * Getters below get value in Host Byte Order.
		 * Setters below set value in Network Byte Order.
		 *
		 * Multi-byte values are laid out big-endian: the most significant byte is
		 * at data[i]. No bounds checking is performed; callers must guarantee that
		 * every accessed byte lies inside the buffer.
		 */
		static uint8_t Get1Byte(const uint8_t* data, size_t i)
		{
			return data[i];
		}

		static uint16_t Get2Bytes(const uint8_t* data, size_t i)
		{
			return uint16_t{ data[i + 1] } | uint16_t{ data[i] } << 8;
		}

		static uint32_t Get3Bytes(const uint8_t* data, size_t i)
		{
			return uint32_t{ data[i + 2] } | uint32_t{ data[i + 1] } << 8 | uint32_t{ data[i] } << 16;
		}

		/**
		 * Read a 24-bit two's complement value and sign-extend it to 32 bits.
		 */
		static int32_t Get3BytesSigned(const uint8_t* data, size_t i)
		{
			const uint32_t raw = Byte::Get3Bytes(data, i);

			// Bit 23 is the sign bit of the 24-bit value; replicate it into the top byte.
			const uint32_t extended = (raw & 0x00800000u) != 0u ? (raw | 0xFF000000u) : raw;

			return static_cast<int32_t>(extended);
		}

		static uint32_t Get4Bytes(const uint8_t* data, size_t i)
		{
			return uint32_t{ data[i + 3] } | uint32_t{ data[i + 2] } << 8 |
			       uint32_t{ data[i + 1] } << 16 | uint32_t{ data[i] } << 24;
		}

		static uint64_t Get8Bytes(const uint8_t* data, size_t i)
		{
			return uint64_t{ Byte::Get4Bytes(data, i) } << 32 | Byte::Get4Bytes(data, i + 4);
		}

		static void Set1Byte(uint8_t* data, size_t i, uint8_t value)
		{
			data[i] = value;
		}

		static void Set2Bytes(uint8_t* data, size_t i, uint16_t value)
		{
			data[i + 1] = static_cast<uint8_t>(value);
			data[i]     = static_cast<uint8_t>(value >> 8);
		}

		static void Set3Bytes(uint8_t* data, size_t i, uint32_t value)
		{
			data[i + 2] = static_cast<uint8_t>(value);
			data[i + 1] = static_cast<uint8_t>(value >> 8);
			data[i]     = static_cast<uint8_t>(value >> 16);
		}

		/**
		 * Write the low 24 bits of a signed value (two's complement).
		 */
		static void Set3BytesSigned(uint8_t* data, size_t i, int32_t value)
		{
			// Work on the unsigned bit pattern: int32_t -> uint32_t conversion is
			// exact (modulo 2^32), whereas right-shifting a negative signed value is
			// implementation-defined. Storing through uint8_t also keeps this setter
			// consistent with the others (no int8_t intermediate).
			Byte::Set3Bytes(data, i, static_cast<uint32_t>(value));
		}

		static void Set4Bytes(uint8_t* data, size_t i, uint32_t value)
		{
			data[i + 3] = static_cast<uint8_t>(value);
			data[i + 2] = static_cast<uint8_t>(value >> 8);
			data[i + 1] = static_cast<uint8_t>(value >> 16);
			data[i]     = static_cast<uint8_t>(value >> 24);
		}

		static void Set8Bytes(uint8_t* data, size_t i, uint64_t value)
		{
			data[i + 7] = static_cast<uint8_t>(value);
			data[i + 6] = static_cast<uint8_t>(value >> 8);
			data[i + 5] = static_cast<uint8_t>(value >> 16);
			data[i + 4] = static_cast<uint8_t>(value >> 24);
			data[i + 3] = static_cast<uint8_t>(value >> 32);
			data[i + 2] = static_cast<uint8_t>(value >> 40);
			data[i + 1] = static_cast<uint8_t>(value >> 48);
			data[i]     = static_cast<uint8_t>(value >> 56);
		}

		/**
		 * Whether size is a multiple of 4. Only enabled for unsigned types.
		 */
		template<typename T>
		static std::enable_if_t<std::is_unsigned_v<T>, bool> IsPaddedTo4Bytes(T size)
		{
			return (size & 0x03) == 0u;
		}

		/**
		 * Round size up to the next multiple of 4 (unchanged if already aligned).
		 * Wraps around for values within 3 of the type's maximum.
		 */
		template<typename T>
		static std::enable_if_t<std::is_unsigned_v<T>, T> PadTo4Bytes(T size)
		{
			return (size + 3) & ~static_cast<T>(0x03);
		}
	};

	class Bits
	{
	public:
		/**
		 * Population count of a 16-bit mask: how many of its bits are 1 (0..16).
		 */
		static size_t CountSetBits(const uint16_t mask)
		{
			const auto setBits = __builtin_popcount(static_cast<unsigned int>(mask));

			return static_cast<size_t>(setBits);
		}
	};

	class Crypto
	{
	public:
		static void ClassInit();
		static void ClassDestroy();

		/**
		 * Pseudo-random integer in [min, max]. If min > max, max is returned.
		 *
		 * Driven by a per-thread linear congruential generator, so the sequence
		 * is predictable: do not use it for secrets. The raw draw is
		 * ((seed >> 4) & 0x7FFF7FFF), which is always below 2^28, so ranges wider
		 * than that are not fully covered.
		 */
		static uint32_t GetRandomUInt(uint32_t min, uint32_t max)
		{
			// NOTE: This is the original, but produces very small values.
			// Crypto::seed = (214013 * Crypto::seed) + 2531011;
			// return (((Crypto::seed>>16)&0x7FFF) % (max - min + 1)) + min;

			// This seems to produce better results.
			Crypto::seed = uint32_t{ ((214013 * Crypto::seed) + 2531011) };

			if (min > max)
			{
				min = max;
			}

			// Width of [min, max] computed in 64 bits: for [0, UINT32_MAX] it is
			// 2^32, which does not fit in uint32_t. This replaces the former
			// "max == UINT32_MAX => --max" workaround, which made max unreachable
			// (e.g. GetRandomUInt(UINT32_MAX, UINT32_MAX) returned UINT32_MAX - 1).
			const uint64_t range = uint64_t{ max } - uint64_t{ min } + 1;
			const uint64_t draw  = (Crypto::seed >> 4) & 0x7FFF7FFF;

			// draw % range <= max - min, so adding min cannot exceed max.
			return static_cast<uint32_t>((draw % range) + min);
		}

		/**
		 * String of len characters drawn from [0-9a-z] using GetRandomUInt().
		 * len is capped at 64.
		 */
		static std::string GetRandomString(size_t len)
		{
			char buffer[64];
			static const char Chars[] = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b',
				                            'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
				                            'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z' };

			if (len > 64)
			{
				len = 64;
			}

			for (size_t i{ 0 }; i < len; ++i)
			{
				buffer[i] = Chars[GetRandomUInt(0, sizeof(Chars) - 1)];
			}

			return { buffer, len };
		}

		static uint32_t GetCRC32(const uint8_t* data, size_t size);

		static uint32_t GetCRC32c(const uint8_t* data, size_t size);

		// NOTE(review): the returned pointer presumably refers to the thread-local
		// hmacSha1Buffer (overwritten by the next call) — TODO confirm in the .cpp.
		static const uint8_t* GetHmacSha1(const std::string& key, const uint8_t* data, size_t len);

	private:
		thread_local static uint32_t seed;
		thread_local static EVP_MAC* mac;
		thread_local static EVP_MAC_CTX* hmacSha1Ctx;
		thread_local static uint8_t hmacSha1Buffer[];
		static const uint32_t Crc32Table[256];
		static const uint32_t Crc32cTable[256];
	};

	class String
	{
	public:
		/**
		 * Lowercase str in place, byte by byte, via std::tolower (current C locale).
		 */
		static void ToLowerCase(std::string& str)
		{
			// std::tolower requires a value representable as unsigned char (or EOF).
			// Passing a plain char directly is undefined behavior for bytes >= 0x80
			// where char is signed, so convert through unsigned char first.
			std::transform(
			  str.begin(),
			  str.end(),
			  str.begin(),
			  [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
		}

		static std::string Base64Encode(const uint8_t* data, size_t len);

		static std::string Base64Encode(const std::string& str);

		// NOTE(review): ownership/lifetime of the returned buffer is not visible
		// here — TODO confirm in the .cpp before freeing or storing it.
		static uint8_t* Base64Decode(const uint8_t* data, size_t len, size_t& outLen);

		static uint8_t* Base64Decode(const std::string& str, size_t& outLen);
	};

	// T is the base type (uint16_t, uint32_t, ...).
	// N is the max number of bits used in T.
	template<typename T, uint8_t N = 0>
	class Number
	{
	private:
		static constexpr T MaxValue = (N == 0) ? std::numeric_limits<T>::max() : ((1 << N) - 1);
		static constexpr T Mask =
		  (N == 0) ? std::numeric_limits<T>::max() : (static_cast<T>((T(1) << N) - 1));

	public:
		static bool IsEqualThan(T lhs, T rhs)
		{
			static_assert(
			  std::is_same_v<T, uint8_t> || std::is_same_v<T, uint16_t> || std::is_same_v<T, uint32_t> ||
			    std::is_same_v<T, uint64_t>,
			  "T must be uint8_t, uint16_t, uint32_t or uint64_t");

			lhs &= Mask;
			rhs &= Mask;

			return (lhs == rhs);
		}

		static bool IsHigherThan(T lhs, T rhs)
		{
			static_assert(
			  std::is_same_v<T, uint8_t> || std::is_same_v<T, uint16_t> || std::is_same_v<T, uint32_t> ||
			    std::is_same_v<T, uint64_t>,
			  "T must be uint8_t, uint16_t, uint32_t or uint64_t");

			lhs &= Mask;
			rhs &= Mask;

			return ((lhs > rhs) && (lhs - rhs <= MaxValue / 2)) ||
			       ((rhs > lhs) && (rhs - lhs > MaxValue / 2));
		}

		static bool IsLowerThan(T lhs, T rhs)
		{
			static_assert(
			  std::is_same_v<T, uint8_t> || std::is_same_v<T, uint16_t> || std::is_same_v<T, uint32_t> ||
			    std::is_same_v<T, uint64_t>,
			  "T must be uint8_t, uint16_t, uint32_t or uint64_t");

			lhs &= Mask;
			rhs &= Mask;

			return ((rhs > lhs) && (rhs - lhs <= MaxValue / 2)) ||
			       ((lhs > rhs) && (lhs - rhs > MaxValue / 2));
		}

		static bool IsHigherOrEqualThan(T lhs, T rhs)
		{
			static_assert(
			  std::is_same_v<T, uint8_t> || std::is_same_v<T, uint16_t> || std::is_same_v<T, uint32_t> ||
			    std::is_same_v<T, uint64_t>,
			  "T must be uint8_t, uint16_t, uint32_t or uint64_t");

			lhs &= Mask;
			rhs &= Mask;

			return (lhs == rhs) || ((lhs > rhs) && (lhs - rhs <= MaxValue / 2)) ||
			       ((rhs > lhs) && (rhs - lhs > MaxValue / 2));
		}

		static bool IsLowerOrEqualThan(T lhs, T rhs)
		{
			static_assert(
			  std::is_same_v<T, uint8_t> || std::is_same_v<T, uint16_t> || std::is_same_v<T, uint32_t> ||
			    std::is_same_v<T, uint64_t>,
			  "T must be uint8_t, uint16_t, uint32_t or uint64_t");

			lhs &= Mask;
			rhs &= Mask;

			return (lhs == rhs) || ((rhs > lhs) && (rhs - lhs <= MaxValue / 2)) ||
			       ((lhs > rhs) && (lhs - rhs > MaxValue / 2));
		}
	};

	class Time
	{
	private:
		// Seconds elapsed between Jan 1, 1900 (NTP epoch) and Jan 1, 1970 (Unix epoch).
		static constexpr uint32_t UnixNtpOffset{ 0x83AA7E80 };
		// Number of NTP fraction units in one second (2^32).
		static constexpr uint64_t NtpFractionalUnit{ uint64_t{ 1 } << 32 };

	public:
		// NTP-style timestamp: whole seconds plus a 32-bit binary fraction of a second.
		struct Ntp
		{
			uint32_t seconds;
			uint32_t fractions;
		};

		/**
		 * Split a millisecond count into whole seconds and 1/2^32 fractions.
		 * No epoch offset is applied; seconds keeps only the low 32 bits of ms / 1000.
		 */
		static Time::Ntp TimeMs2Ntp(uint64_t ms)
		{
			const uint64_t wholeSeconds   = ms / 1000;
			const double secondFraction   = static_cast<double>(ms % 1000) / 1000;

			return Time::Ntp{ static_cast<uint32_t>(wholeSeconds),
				                static_cast<uint32_t>(secondFraction * NtpFractionalUnit) };
		}

		/**
		 * Inverse of TimeMs2Ntp(): seconds * 1000 plus the fraction converted to
		 * milliseconds and rounded to the nearest integer.
		 */
		static uint64_t Ntp2TimeMs(Time::Ntp ntp)
		{
			const uint64_t secondsAsMs = static_cast<uint64_t>(ntp.seconds) * 1000;
			const double fractionAsMs  = (static_cast<double>(ntp.fractions) * 1000) / NtpFractionalUnit;

			return secondsAsMs + static_cast<uint64_t>(std::round(fractionAsMs));
		}

		/**
		 * Convert ms to a 24-bit 6.18 fixed-point seconds value: scale by 2^18,
		 * divide by 1000 rounding half up, keep the low 24 bits (wraps every 64 s).
		 */
		static uint32_t TimeMsToAbsSendTime(uint64_t ms)
		{
			const uint64_t scaled = ((ms << 18) + 500) / 1000;

			return static_cast<uint32_t>(scaled) & 0x00FFFFFF;
		}
	};

	// Bit-granular stream over an internal fixed-size byte buffer
	// (RTC::Consts::TwoBytesRtpExtensionMaxLength bytes), with a bit cursor
	// (`offset`) starting at 0.
	// NOTE(review): all methods are implemented out of line; the per-method notes
	// below are inferred from names and signatures — confirm against the .cpp.
	class BitStream
	{
	public:
		// Presumably copies up to `len` bytes of `data` into the internal buffer — TODO confirm.
		BitStream(uint8_t* data, size_t len);
		~BitStream() = default;

		const uint8_t* GetData() const;
		size_t GetLength() const;
		// Current bit position, presumably.
		uint32_t GetOffset() const;
		void Reset();
		// Readers: presumably consume bits and advance the cursor.
		uint8_t GetBit();
		uint32_t GetBits(size_t count);
		uint32_t GetLeftBits() const;
		uint32_t GetNumBits(uint32_t n) const;
		// Returns std::nullopt on failure; the encoding read is not visible here — TODO confirm.
		std::optional<uint32_t> ReadNs(uint32_t n);
		void SkipBits(size_t count);
		// Writers.
		void Write(uint32_t offset, uint32_t n, uint32_t v);
		void PutBit(uint8_t bit);
		void PutBits(uint32_t count, uint32_t bits);

	private:
		// Offset-explicit overloads of the public writers.
		void PutBit(uint32_t offset, uint8_t bit);
		void PutBits(uint32_t offset, uint32_t count, uint32_t bits);

	private:
		uint8_t data[RTC::Consts::TwoBytesRtpExtensionMaxLength];
		uint32_t len{ 0 };
		uint32_t offset{ 0 };
	};

} // namespace Utils

#endif
