//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

// MIT License
//
// Modifications Copyright (C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#ifndef _CUDA_STD_NUMBERS
#define _CUDA_STD_NUMBERS

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cuda/std/__floating_point/nvfp_types.h>
#include <cuda/std/__type_traits/enable_if.h>
#include <cuda/std/__type_traits/is_floating_point.h>
#include <cuda/std/version>

#ifdef __HIP_PLATFORM_AMD__
#include <amd/amd_utils.h>
#endif

_LIBCUDACXX_BEGIN_NAMESPACE_STD

// Always-false trait whose value depends on _Tp. Using it (instead of a plain
// `false`) keeps the static_assert in the primary `__numbers` template from
// firing until that template is actually instantiated.
template <class _Tp>
struct __numbers_ill_formed_impl : false_type
{};

// Primary template: only reached for types that have no `__numbers`
// specialization in this header. Instantiating it is a hard error, matching
// the [math.constants] rule quoted in the diagnostic below.
template <class _Tp, class /* _Enable */ = void>
struct __numbers
{
  static_assert(__numbers_ill_formed_impl<_Tp>::value,
                "[math.constants] A program that instantiates a primary template of a mathematical constant variable "
                "template is ill-formed.");
};

_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_MSVC(4305) // truncation from 'double' to 'const _Tp'

// Constants for every type accepted by is_floating_point.
// Each value is written as an unsuffixed (i.e. double) literal and explicitly
// converted to _Tp, so narrower types such as float are rounded from the
// double value.
// NOTE(review): because the literals have no `L` suffix, long double only
// receives double precision here -- confirm this is intended (an `L` suffix
// may not be usable in device code).
template <class _Tp>
struct __numbers<_Tp, enable_if_t<_CCCL_TRAIT(is_floating_point, _Tp)>>
{
  // e
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __e() noexcept
  {
    return static_cast<_Tp>(2.718281828459045235360287471352662);
  }
  // log2(e)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __log2e() noexcept
  {
    return static_cast<_Tp>(1.442695040888963407359924681001892);
  }
  // log10(e)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __log10e() noexcept
  {
    return static_cast<_Tp>(0.434294481903251827651128918916605);
  }
  // pi
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __pi() noexcept
  {
    return static_cast<_Tp>(3.141592653589793238462643383279502);
  }
  // 1 / pi
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __inv_pi() noexcept
  {
    return static_cast<_Tp>(0.318309886183790671537767526745028);
  }
  // 1 / sqrt(pi)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __inv_sqrtpi() noexcept
  {
    return static_cast<_Tp>(0.564189583547756286948079451560772);
  }
  // ln(2)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __ln2() noexcept
  {
    return static_cast<_Tp>(0.693147180559945309417232121458176);
  }
  // ln(10)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __ln10() noexcept
  {
    return static_cast<_Tp>(2.302585092994045684017991454684364);
  }
  // sqrt(2)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __sqrt2() noexcept
  {
    return static_cast<_Tp>(1.414213562373095048801688724209698);
  }
  // sqrt(3)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __sqrt3() noexcept
  {
    return static_cast<_Tp>(1.732050807568877293527446341505872);
  }
  // 1 / sqrt(3)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __inv_sqrt3() noexcept
  {
    return static_cast<_Tp>(0.577350269189625764509148780501957);
  }
  // Euler-Mascheroni constant (gamma)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __egamma() noexcept
  {
    return static_cast<_Tp>(0.577215664901532860606512090082402);
  }
  // Golden ratio (phi)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __phi() noexcept
  {
    return static_cast<_Tp>(1.618033988749894848204586834365638);
  }
};

_CCCL_DIAG_POP

#if _LIBCUDACXX_HAS_NVFP16()
// NOTE(HIP/AMD): on ROCm 7.10 and earlier, __half values cannot be initialized in a constant expression, so this
// specialization is only provided for ROCm 7.11 and newer.
#  if LIBHIPCXX_ROCM_VERSION_GE(7, 11, 0)
template <>
struct __numbers<__half>
{
  // Wraps a raw 16-bit pattern into a __half; all accessors below go through it.
  // TODO(HIP/AMD): remove this WAR once __half_raw has only the x member.
  // The designated initializer is necessary as otherwise the data member will be
  // active and the passed value will be interpreted as a FLOAT16 instead of an
  // unsigned int.
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __from_bits(unsigned short __bits) noexcept
  {
    return __half_raw{.x = __bits};
  }

  // The patterns below are the binary16 encodings of the constants named by
  // each accessor (rounded to nearest).
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __e() noexcept
  {
    return __from_bits(0x4170u);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __log2e() noexcept
  {
    return __from_bits(0x3dc5u);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __log10e() noexcept
  {
    return __from_bits(0x36f3u);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __pi() noexcept
  {
    return __from_bits(0x4248u);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __inv_pi() noexcept
  {
    return __from_bits(0x3518u);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __inv_sqrtpi() noexcept
  {
    return __from_bits(0x3883u);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __ln2() noexcept
  {
    return __from_bits(0x398cu);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __ln10() noexcept
  {
    return __from_bits(0x409bu);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __sqrt2() noexcept
  {
    return __from_bits(0x3da8u);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __sqrt3() noexcept
  {
    return __from_bits(0x3eeeu);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __inv_sqrt3() noexcept
  {
    return __from_bits(0x389eu);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __egamma() noexcept
  {
    // Intentionally identical to __inv_sqrt3(): gamma (~0.57722) and
    // 1/sqrt(3) (~0.57735) round to the same binary16 value.
    return __from_bits(0x389eu);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __phi() noexcept
  {
    return __from_bits(0x3e79u);
  }
};
#  endif // LIBHIPCXX_ROCM_VERSION_GE(7, 11, 0)
#endif // _LIBCUDACXX_HAS_NVFP16()

#if _LIBCUDACXX_HAS_NVBF16()
template <>
struct __numbers<__hip_bfloat16>
{
  // Wraps a raw 16-bit pattern into a __hip_bfloat16; all accessors below go
  // through it.
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __hip_bfloat16 __from_bits(unsigned short __bits) noexcept
  {
    return __hip_bfloat16_raw{__bits};
  }

  // The patterns below are the bfloat16 encodings (upper 16 bits of the
  // binary32 value, rounded to nearest) of the constants named by each accessor.
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __hip_bfloat16 __e() noexcept
  {
    return __from_bits(0x402eu);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __hip_bfloat16 __log2e() noexcept
  {
    return __from_bits(0x3fb9u);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __hip_bfloat16 __log10e() noexcept
  {
    return __from_bits(0x3edeu);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __hip_bfloat16 __pi() noexcept
  {
    return __from_bits(0x4049u);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __hip_bfloat16 __inv_pi() noexcept
  {
    return __from_bits(0x3ea3u);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __hip_bfloat16 __inv_sqrtpi() noexcept
  {
    return __from_bits(0x3f10u);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __hip_bfloat16 __ln2() noexcept
  {
    return __from_bits(0x3f31u);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __hip_bfloat16 __ln10() noexcept
  {
    return __from_bits(0x4013u);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __hip_bfloat16 __sqrt2() noexcept
  {
    return __from_bits(0x3fb5u);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __hip_bfloat16 __sqrt3() noexcept
  {
    return __from_bits(0x3fdeu);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __hip_bfloat16 __inv_sqrt3() noexcept
  {
    return __from_bits(0x3f14u);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __hip_bfloat16 __egamma() noexcept
  {
    // Intentionally identical to __inv_sqrt3(): gamma (~0.57722) and
    // 1/sqrt(3) (~0.57735) round to the same bfloat16 value.
    return __from_bits(0x3f14u);
  }
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __hip_bfloat16 __phi() noexcept
  {
    return __from_bits(0x3fcfu);
  }
};
#endif // _LIBCUDACXX_HAS_NVBF16()

namespace numbers
{

// Mathematical constant variable templates ([math.constants]).
// Each one is initialized from the matching __numbers<_Tp> accessor; a type
// without a __numbers specialization reaches the primary template, whose
// static_assert rejects it.
template <class _Tp>
inline constexpr _Tp e_v = __numbers<_Tp>::__e();
template <class _Tp>
inline constexpr _Tp log2e_v = __numbers<_Tp>::__log2e();
template <class _Tp>
inline constexpr _Tp log10e_v = __numbers<_Tp>::__log10e();
template <class _Tp>
inline constexpr _Tp pi_v = __numbers<_Tp>::__pi();
template <class _Tp>
inline constexpr _Tp inv_pi_v = __numbers<_Tp>::__inv_pi();
template <class _Tp>
inline constexpr _Tp inv_sqrtpi_v = __numbers<_Tp>::__inv_sqrtpi();
template <class _Tp>
inline constexpr _Tp ln2_v = __numbers<_Tp>::__ln2();
template <class _Tp>
inline constexpr _Tp ln10_v = __numbers<_Tp>::__ln10();
template <class _Tp>
inline constexpr _Tp sqrt2_v = __numbers<_Tp>::__sqrt2();
template <class _Tp>
inline constexpr _Tp sqrt3_v = __numbers<_Tp>::__sqrt3();
template <class _Tp>
inline constexpr _Tp inv_sqrt3_v = __numbers<_Tp>::__inv_sqrt3();
template <class _Tp>
inline constexpr _Tp egamma_v = __numbers<_Tp>::__egamma();
template <class _Tp>
inline constexpr _Tp phi_v = __numbers<_Tp>::__phi();

// Explicit specializations for the 16-bit floating-point types, declared with
// _CCCL_GLOBAL_CONSTANT instead of plain `inline constexpr`.
#if !_CCCL_COMPILER(MSVC)
// Skipped on MSVC, which reports
// "error: A __device__ variable template cannot have a const qualified type on Windows".
#  if _LIBCUDACXX_HAS_NVFP16()
// NOTE(HIP/AMD): on ROCm 7.10 and earlier, __half values cannot be initialized in a constant expression.
#    if LIBHIPCXX_ROCM_VERSION_GE(7, 11, 0)
template <>
_CCCL_GLOBAL_CONSTANT __half e_v<__half> = __numbers<__half>::__e();
template <>
_CCCL_GLOBAL_CONSTANT __half log2e_v<__half> = __numbers<__half>::__log2e();
template <>
_CCCL_GLOBAL_CONSTANT __half log10e_v<__half> = __numbers<__half>::__log10e();
template <>
_CCCL_GLOBAL_CONSTANT __half pi_v<__half> = __numbers<__half>::__pi();
template <>
_CCCL_GLOBAL_CONSTANT __half inv_pi_v<__half> = __numbers<__half>::__inv_pi();
template <>
_CCCL_GLOBAL_CONSTANT __half inv_sqrtpi_v<__half> = __numbers<__half>::__inv_sqrtpi();
template <>
_CCCL_GLOBAL_CONSTANT __half ln2_v<__half> = __numbers<__half>::__ln2();
template <>
_CCCL_GLOBAL_CONSTANT __half ln10_v<__half> = __numbers<__half>::__ln10();
template <>
_CCCL_GLOBAL_CONSTANT __half sqrt2_v<__half> = __numbers<__half>::__sqrt2();
template <>
_CCCL_GLOBAL_CONSTANT __half sqrt3_v<__half> = __numbers<__half>::__sqrt3();
template <>
_CCCL_GLOBAL_CONSTANT __half inv_sqrt3_v<__half> = __numbers<__half>::__inv_sqrt3();
template <>
_CCCL_GLOBAL_CONSTANT __half egamma_v<__half> = __numbers<__half>::__egamma();
template <>
_CCCL_GLOBAL_CONSTANT __half phi_v<__half> = __numbers<__half>::__phi();
#    endif // LIBHIPCXX_ROCM_VERSION_GE(7, 11, 0)
#  endif // _LIBCUDACXX_HAS_NVFP16()

#  if _LIBCUDACXX_HAS_NVBF16()
template <>
_CCCL_GLOBAL_CONSTANT __hip_bfloat16 e_v<__hip_bfloat16> = __numbers<__hip_bfloat16>::__e();
template <>
_CCCL_GLOBAL_CONSTANT __hip_bfloat16 log2e_v<__hip_bfloat16> = __numbers<__hip_bfloat16>::__log2e();
template <>
_CCCL_GLOBAL_CONSTANT __hip_bfloat16 log10e_v<__hip_bfloat16> = __numbers<__hip_bfloat16>::__log10e();
template <>
_CCCL_GLOBAL_CONSTANT __hip_bfloat16 pi_v<__hip_bfloat16> = __numbers<__hip_bfloat16>::__pi();
template <>
_CCCL_GLOBAL_CONSTANT __hip_bfloat16 inv_pi_v<__hip_bfloat16> = __numbers<__hip_bfloat16>::__inv_pi();
template <>
_CCCL_GLOBAL_CONSTANT __hip_bfloat16 inv_sqrtpi_v<__hip_bfloat16> = __numbers<__hip_bfloat16>::__inv_sqrtpi();
template <>
_CCCL_GLOBAL_CONSTANT __hip_bfloat16 ln2_v<__hip_bfloat16> = __numbers<__hip_bfloat16>::__ln2();
template <>
_CCCL_GLOBAL_CONSTANT __hip_bfloat16 ln10_v<__hip_bfloat16> = __numbers<__hip_bfloat16>::__ln10();
template <>
_CCCL_GLOBAL_CONSTANT __hip_bfloat16 sqrt2_v<__hip_bfloat16> = __numbers<__hip_bfloat16>::__sqrt2();
template <>
_CCCL_GLOBAL_CONSTANT __hip_bfloat16 sqrt3_v<__hip_bfloat16> = __numbers<__hip_bfloat16>::__sqrt3();
template <>
_CCCL_GLOBAL_CONSTANT __hip_bfloat16 inv_sqrt3_v<__hip_bfloat16> = __numbers<__hip_bfloat16>::__inv_sqrt3();
template <>
_CCCL_GLOBAL_CONSTANT __hip_bfloat16 egamma_v<__hip_bfloat16> = __numbers<__hip_bfloat16>::__egamma();
template <>
_CCCL_GLOBAL_CONSTANT __hip_bfloat16 phi_v<__hip_bfloat16> = __numbers<__hip_bfloat16>::__phi();
#  endif // _LIBCUDACXX_HAS_NVBF16()
#endif // !_CCCL_COMPILER(MSVC)

// Non-template `double` constants: each is the `double` instantiation of the
// corresponding variable template above.
inline constexpr double e          = e_v<double>;
inline constexpr double log2e      = log2e_v<double>;
inline constexpr double log10e     = log10e_v<double>;
inline constexpr double pi         = pi_v<double>;
inline constexpr double inv_pi     = inv_pi_v<double>;
inline constexpr double inv_sqrtpi = inv_sqrtpi_v<double>;
inline constexpr double ln2        = ln2_v<double>;
inline constexpr double ln10       = ln10_v<double>;
inline constexpr double sqrt2      = sqrt2_v<double>;
inline constexpr double sqrt3      = sqrt3_v<double>;
inline constexpr double inv_sqrt3  = inv_sqrt3_v<double>;
inline constexpr double egamma     = egamma_v<double>;
inline constexpr double phi        = phi_v<double>;

} // namespace numbers

_LIBCUDACXX_END_NAMESPACE_STD

#endif // _CUDA_STD_NUMBERS
