import libtbx.load_env
import os
import subprocess
Import("env_etc")
from shutil import copy, which

if env_etc.enable_kokkos:

  # ---------------------------------------------------------------------------
  # Environment for the Makefile-based build of libkokkos.a
  # The Kokkos build system is called directly and is configured through
  # environment variables, so defaults are filled in for any that are unset.
  os.environ.setdefault('KOKKOS_DEVICES', "OpenMP")
  kokkos_devices = os.environ['KOKKOS_DEVICES']
  use_openmp = "OpenMP" in kokkos_devices
  use_cuda = "Cuda" in kokkos_devices
  use_hip = "HIP" in kokkos_devices
  use_sycl = "SYCL" in kokkos_devices

  # Source locations default to directories next to the simtbx distribution;
  # the defaults are only computed when the variable is not already set.
  if 'KOKKOS_PATH' not in os.environ:
    os.environ['KOKKOS_PATH'] = libtbx.env.under_dist('simtbx', '../../kokkos')
  if 'KOKKOSKERNELS_PATH' not in os.environ:
    os.environ['KOKKOSKERNELS_PATH'] = libtbx.env.under_dist('simtbx', '../../kokkos-kernels')
  os.environ.setdefault('KOKKOS_ARCH', "HSW")
  if use_cuda:
    os.environ.setdefault('KOKKOS_CUDA_OPTIONS', "enable_lambda")
  # CXXFLAGS is always overwritten, regardless of any previous value
  os.environ['CXXFLAGS'] = '-O3 -fPIC -DCUDAREAL=double'

  # Library search paths; $(CUDA_HOME) is left unexpanded here
  ldflag_parts = ["-Llib"]
  if use_cuda:
    ldflag_parts.append(" -L$(CUDA_HOME)/lib64 -L$(CUDA_HOME)/lib64/stubs")
    if os.getenv('CUDA_COMPATIBILITY'):
      ldflag_parts.append(" -L$(CUDA_HOME)/compat")
  library_flags = "".join(ldflag_parts)
  os.environ['LDFLAGS'] = library_flags

  # Libraries to link; the CUDA runtime and driver are added for CUDA builds
  linked_libraries = "-lkokkoscontainers -lkokkoscore -ldl" + (
    " -lcudart -lcuda" if use_cuda else "")
  os.environ['LDLIBS'] = linked_libraries

  kokkos_lib = 'libkokkos.a'
  kokkos_cxxflags = None

  # Remember the caller's compiler (None when CXX is unset) so that it can be
  # restored later, then pick the compiler matching the selected backend.
  original_cxx = os.environ.get('CXX')
  if use_cuda:
    os.environ['CXX'] = os.path.join(os.environ['KOKKOS_PATH'], 'bin', 'nvcc_wrapper')
  elif use_hip:
    os.environ['CXX'] = 'hipcc'
  else:
    os.environ.setdefault('CXX', 'g++')

  # Drop the conda/lib directory from LD_LIBRARY_PATH for the duration of the
  # build; the original value is kept so that it can be restored later.
  original_ld_path = os.environ.get('LD_LIBRARY_PATH')
  if original_ld_path is not None:
    ld_path = [entry for entry in original_ld_path.split(':')
               if 'opt/mamba/envs/psana_env/lib' not in entry]
    os.environ['LD_LIBRARY_PATH'] = ":".join(ld_path)

  print('='*79)
  print('Building Kokkos')
  print('-'*79)
  # Build the static library with the Kokkos Makefile, run from the Kokkos
  # source directory.
  returncode = subprocess.call(['make', '-f', 'Makefile.kokkos', kokkos_lib],
                                cwd=os.environ['KOKKOS_PATH'])
  if returncode != 0:
    # Stay best-effort (do not abort the configuration), but make the failure
    # visible: a library left over from an earlier build could still be
    # present in the source directory and would be copied below.
    print('Warning: make exited with status {status} while building {lib}'.format(
      status=returncode, lib=kokkos_lib))
  print()

  print('Copying Kokkos library')
  print('-'*79)
  # The library is produced in the Kokkos source directory; copy it into the
  # lib directory of the build.
  src = os.path.join(os.environ['KOKKOS_PATH'], kokkos_lib)
  dst = os.path.join(libtbx.env.under_build('lib'), kokkos_lib)
  if os.path.isfile(src):
    copy(src, dst)
    print('Copied')
    print('  source:     ', src)
    print('  destination:', dst)
  else:
    print('Error: {src} does not exist'.format(src=src))
  print()

  # =============================================================================
  # Build kokkos with CMake
  # The CMake build tree must be located outside of the build directory;
  # otherwise kokkos-kernels will find that build tree instead of the installed
  # configuration in build/lib/cmake/Kokkos
  # The error will look like
  #
  #   -- The project name is: KokkosKernels
  #   CMake Error at /dev/shm/bkpoon/software/xfel/build/kokkos/KokkosConfig.cmake:48 (INCLUDE):
  #     INCLUDE could not find requested file:

  #       /dev/shm/bkpoon/software/xfel/build/kokkos/KokkosTargets.cmake
  #   Call Stack (most recent call first):
  #     CMakeLists.txt:107 (FIND_PACKAGE)
  #
  # kokkos will be installed in build with the libraries in lib, not lib64
  #
  # TODO:
  # - Change over from Makefile version to cmake version
  # - Change system configuration classes to select backends with cmake
  #   flags instead of environment variables
  # - Verify that the libraries from the cmake version work the same as libkokkos.a

  def _run_build_step(description, command, cwd):
    """Run one build command in cwd and return its exit status.

    A non-zero status is reported with a warning but does not abort, so the
    remaining steps of the configuration still run.
    """
    status = subprocess.call(command, cwd=cwd)
    if status != 0:
      print('Warning: {} failed with exit status {}'.format(description, status))
    return status

  OnOff = {True:'ON', False:'OFF'}
  supported_architectures = ['Native', 'HSW', 'Zen', 'Zen2', 'Zen3', 'Volta70', 'Turing75', 'Ampere80', 'Vega908', 'Vega90A', 'XeHP', 'PVC']
  # KOKKOS_ARCH may list several architectures separated by commas (whitespace
  # is accepted as well). Whole names are compared: a substring test would
  # also switch on "Zen" whenever "Zen2" or "Zen3" is requested.
  requested_architectures = set(
    os.environ.get('KOKKOS_ARCH', '').replace(',', ' ').split())
  architectures = {}
  for arch in supported_architectures:
    architectures[arch] = OnOff[arch in requested_architectures]

  cmake_is_available = which('cmake')
  if cmake_is_available:
    # Both packages are installed into the top of the build directory, with
    # their libraries in lib (not lib64).
    install_prefix = libtbx.env.under_build('.')

    print('='*79)
    print('Building kokkos with cmake')
    print('-'*79)
    kokkos_build_dir = libtbx.env.under_dist('simtbx', '../../kokkos_build')
    # makedirs also creates missing parent directories and tolerates an
    # already existing build directory
    os.makedirs(kokkos_build_dir, exist_ok=True)
    returncode = _run_build_step('kokkos cmake configuration', [
        'cmake',
        os.environ['KOKKOS_PATH'],
        '-DCMAKE_CXX_STANDARD={}'.format('17'),
        '-DCMAKE_INSTALL_PREFIX={}'.format(install_prefix),
        '-DCMAKE_INSTALL_LIBDIR=lib',
        '-DBUILD_SHARED_LIBS={}'.format(OnOff[True]),
        '-DKokkos_ARCH_NATIVE={}'.format(architectures['Native']),
        '-DKokkos_ARCH_HSW={}'.format(architectures['HSW']),
        '-DKokkos_ARCH_ZEN={}'.format(architectures['Zen']),
        '-DKokkos_ARCH_ZEN2={}'.format(architectures['Zen2']),
        '-DKokkos_ARCH_ZEN3={}'.format(architectures['Zen3']),
        '-DKokkos_ARCH_VOLTA70={}'.format(architectures['Volta70']),
        '-DKokkos_ARCH_TURING75={}'.format(architectures['Turing75']),
        '-DKokkos_ARCH_AMPERE80={}'.format(architectures['Ampere80']),
        '-DKokkos_ARCH_VEGA908={}'.format(architectures['Vega908']),
        '-DKokkos_ARCH_VEGA90A={}'.format(architectures['Vega90A']),
        '-DKokkos_ARCH_INTEL_XEHP={}'.format(architectures['XeHP']),
        '-DKokkos_ARCH_INTEL_PVC={}'.format(architectures['PVC']),
        '-DKokkos_ENABLE_SERIAL=ON',
        '-DKokkos_ENABLE_OPENMP={}'.format(OnOff[use_openmp]),
        '-DKokkos_ENABLE_CUDA={}'.format(OnOff[use_cuda]),
        '-DKokkos_ENABLE_HIP={}'.format(OnOff[use_hip]),
        '-DKokkos_ENABLE_SYCL={}'.format(OnOff[use_sycl]),
        '-DKokkos_ENABLE_IMPL_MDSPAN=ON'
      ],
      kokkos_build_dir)

    returncode = _run_build_step(
      'kokkos installation', ['make', '-j', '4', 'install'], kokkos_build_dir)

    # ---------------------------------------------------------------------------
    # Build kokkos-kernels with CMake.
    # Turn off all ETI builds for now, until needed, for maximum machine compatibility
    print('='*79)
    print('Building kokkos_kernels')
    print('-'*79)
    kokkos_kernels_build_dir = libtbx.env.under_dist('simtbx', '../../kokkos-kernels/build')
    os.makedirs(kokkos_kernels_build_dir, exist_ok=True)
    returncode = _run_build_step('kokkos-kernels cmake configuration', [
        'cmake',
        os.environ['KOKKOSKERNELS_PATH'],
        '-DCMAKE_INSTALL_PREFIX={}'.format(install_prefix),
        '-DCMAKE_INSTALL_LIBDIR=lib',
        # point kokkos-kernels at the kokkos installation made above
        '-DKokkos_ROOT={}'.format(install_prefix),
        '-DKokkosKernels_ADD_DEFAULT_ETI={}'.format(OnOff[False]),
        '-DKokkosKernels_INST_LAYOUTLEFT:BOOL={}'.format(OnOff[False]),
        '-DKokkosKernels_INST_LAYOUTRIGHT:BOOL={}'.format(OnOff[False]),
        '-DKokkosKernels_ENABLE_TPL_CUBLAS={}'.format(OnOff[False]),
        '-DKokkosKernels_ENABLE_TPL_CUSPARSE={}'.format(OnOff[False])
      ],
      kokkos_kernels_build_dir)
    returncode = _run_build_step(
      'kokkos-kernels build', ['make', '-j', '4'], kokkos_kernels_build_dir)
    returncode = _run_build_step(
      'kokkos-kernels installation', ['make', '-j', '4', 'install'],
      kokkos_kernels_build_dir)
  else:
    print('*'*79)
    print('cmake was not found')
    print('Skipping builds of kokkos and kokkos-kernels')
    print('*'*79)

  # =============================================================================
  print('Getting environment variables')
  print('-'*79)
  # Query Makefile.kokkos for its compiler flags via the print-cxx-flags target
  make_output = subprocess.check_output(
    ['make', '-f', 'Makefile.kokkos', 'print-cxx-flags'],
    cwd=os.environ['KOKKOS_PATH'])
  # The flags are read from the second line of the output; that line is
  # exported unmodified as KOKKOS_CXXFLAGS
  flags_line = make_output.split(b'\n')[1].decode('utf8')
  os.environ['KOKKOS_CXXFLAGS'] = flags_line
  kokkos_cxxflags = flags_line.split()
  # If the line starts with an "echo" token, drop that token and strip the
  # double quotes around the remaining tokens
  if kokkos_cxxflags[:1] == ['echo']:
    kokkos_cxxflags = [flag.strip('"') for flag in kokkos_cxxflags[1:]]
  print('KOKKOS_CXXFLAGS:', kokkos_cxxflags)
  print('='*79)


  # Put back the caller's CXX and LD_LIBRARY_PATH. Each variable is restored
  # only if it had a value before it was modified for the Kokkos build.
  if original_cxx is not None:
    build_cxx = os.environ['CXX']
    os.environ['CXX'] = original_cxx
    print("RESTORE CXX")
    print('   old:', build_cxx)
    print('   new:', original_cxx)
  if original_ld_path is not None:
    os.environ['LD_LIBRARY_PATH'] = original_ld_path
