#!/usr/bin/env python3
#
# Copyright(C) 2025 Advanced Micro Devices, Inc. All rights reserved.
#
#   gpurun: Application process launch utility for GPUs.
#           This utility ensures the process will enable either a single
#           GPU or the number specified with -md (multi-device) option.
#           It launches the application binary with either the 'taskset'
#           or 'numactl' utility so the process only runs on CPU cores
#           in the same NUMA domain as the selected GPUs.
#
#           This utility sets environment variable ROCR_VISIBLE_DEVICES
#           to selected GPUs ONLY if it was not already set by the
#           caller's environment AND the number of GPUs is not 1.
#
#         Future:
#           This utility also sets environment variable HSA_CU_MASK
#           to control which CUs are available to the process.
#           HSA_CU_MASK is set only when more than one OpenMPI process
#           (rank) will utilize the same GPU and it is not preset.
#           Lastly, it sets env variable OMPX_TARGET_TEAM_PROCS to the
#           number of CUs available to the process after masking.
#
#   $ gpurun -topo
#   Topology     numa_balancing: 0   PageSize: [always] madvise never
#
#   GPU     Node  Affinity       UUID               Cores
#    0        0       0       GPU-b256278bf70405e2    0-23,96-119
#    1        1       1       GPU-a33557394e2c744e    24-47,120-143
#    2        2       2       GPU-4f78640baf57e5f0    48-71,144-167
#    3        3       3       GPU-b66921701d196e10    72-95,168-191

import subprocess
import re
import os
import sys

if sys.version_info < (3, 7):
    # subprocess.run(capture_output=...), used throughout, requires Python 3.7.
    print("require minimum python version 3.7 or later")
    # Non-zero status so callers can tell that nothing was launched.
    sys.exit(1)

# Verbosity level (0-3). The __main__ section overwrites it from -v/-vv/-vvv;
# this default lets the helper functions run when the file is imported.
debug_numa = 0

# amdsmi is optional: when it cannot be imported, GPU topology is read from
# the rocm-smi command-line tool instead.
noAmdSmi = False

try:
    from amdsmi import *
except ImportError:
    noAmdSmi = True


def get_amd_smi_static_numa():
    """
    Collect per-GPU NUMA information through the amdsmi Python bindings.

    For each handle returned by amdsmi_get_processor_handles(), in that order:
      * numa_node    <- amdsmi_topo_get_numa_node_number(device)
      * hip_uuid     <- amdsmi_get_gpu_enumeration_info(device)['hip_uuid']
      * gpu_affinity <- amdsmi_get_gpu_topo_numa_affinity(device). A value
                        of -1 is replaced by the GPU's NUMA node. If the
                        query raises AmdSmiException, the previous GPU's
                        value is reused (0 for the first GPU).

    Returns:
        (gpu_affinity, numa_node, hip_uuid), lists indexed by GPU number.
        If no affinity could be collected, gpu_affinity is [0].

    amdsmi errors are printed and whatever was collected so far is
    returned. Exits with status 1 when amdsmi reports no GPUs.
    """
    gpu_affinity = []
    numa_node = []
    hip_uuid = []

    initialized = False
    try:
        amdsmi_init()
        initialized = True
        devices = amdsmi_get_processor_handles()
        affi_node = 0
        if len(devices) == 0:
            print("No GPUs on machine")
            sys.exit(1)
        for gpu_id, device in enumerate(devices):
            info = amdsmi_get_gpu_enumeration_info(device)
            node_number = amdsmi_topo_get_numa_node_number(device)
            if debug_numa > 2: print("****");print("gpu_id: ", gpu_id);print("Numa: ",node_number)
            # One entry per device, in enumeration order, so gpu_id == list index.
            numa_node.append(node_number)
            hip_uuid.append(info['hip_uuid'])
            if debug_numa > 2: print("hip_id: ", info['hip_id']); print("hip_uuid: ", info['hip_uuid'])

            try:
                affi_node = amdsmi_get_gpu_topo_numa_affinity(device)
                if affi_node == -1: affi_node = node_number
                if debug_numa > 2: print("Affinity: ", affi_node)
            except AmdSmiException:
                # Affinity not reported (presumably partition modes such as
                # cpx/tpx, as in the rocm-smi path): keep the previous value.
                if debug_numa > 2: print("N/A")

            gpu_affinity.append(affi_node)
    except AmdSmiException as e:
        print(f"Error executing amd-smi: {e}")
    finally:
        # Release the library before the application is launched.
        if initialized:
            try:
                amdsmi_shut_down()
            except AmdSmiException as e:
                print(f"Error executing amd-smi: {e}")

    if len(gpu_affinity) == 0:
        gpu_affinity.append(0)

    if debug_numa > 2: print("get_amd_smi_static_numa:" , gpu_affinity, numa_node, hip_uuid)
    return gpu_affinity, numa_node, hip_uuid

def parse_rocm_smi_toponuma():
    """
    Collect per-GPU NUMA information from the rocm-smi command-line tool.

    Runs 'rocm-smi --showuniqueid' for GPU unique IDs and
    'rocm-smi --showtoponuma' for each GPU's NUMA node and NUMA affinity.
    The results are stored in lists indexed by GPU number.

    Returns:
        (gpu_affinity, numa_node, hip_uuid). hip_uuid entries look like
        "GPU-<hex id>". GPU numbers with no matching line hold None. If no
        affinity line was found at all, gpu_affinity is [0].
        Returns (None, None, None) if rocm-smi is missing or fails.
    """
    try:
        # Execute the rocm-smi command
        UIresult = subprocess.run(['rocm-smi', '--showuniqueid'], capture_output=True, text=True, check=True)
        UIoutput = UIresult.stdout
    except subprocess.CalledProcessError as e:
        print(f"Error executing rocm-smi: {e}")
        return None, None, None
    except FileNotFoundError:
        print("Error: 'rocm-smi' command not found. Ensure ROCm is installed and in your PATH.")
        return None, None, None

    def _store(values, index, value):
        # Grow the list with None placeholders as needed, then set values[index].
        while len(values) <= index:
            values.append(None)
        values[index] = value

    hip_uuid = []
    patternUI = re.compile(r"GPU\[(\d+)\]\s+:\s+Unique\s+ID:\s+0x([0-9a-fA-F]+)")
    for line in UIoutput.splitlines():
        match = patternUI.search(line)
        if match:
            _store(hip_uuid, int(match.group(1)), "GPU-" + match.group(2))

    try:
        # Execute the rocm-smi command
        result = subprocess.run(['rocm-smi', '--showtoponuma'], capture_output=True, text=True, check=True)
        output = result.stdout
    except subprocess.CalledProcessError as e:
        print(f"Error executing rocm-smi: {e}")
        # Same 3-tuple shape as every other exit: callers unpack three values.
        return None, None, None
    except FileNotFoundError:
        print("Error: 'rocm-smi' command not found. Ensure ROCm is installed and in your PATH.")
        return None, None, None

    gpu_affinity = []
    numa_node = []

    # Matches lines such as "GPU[0]  : (Topology) Numa Affinity: 0" and
    # "GPU[0]  : (Topology) Numa Node: 0".
    patternAffy = re.compile(r"GPU\[(\d+)\]\s+:\s+\(Topology\) Numa Affinity:\s+(\d+)")
    patternErrA = re.compile(r"get_numa_affinity_topology, Not supported on the given system")
    patternNode = re.compile(r"GPU\[(\d+)\]\s+:\s+\(Topology\) Numa Node:\s+(\d+)")
    patternGpu = re.compile(r"GPU\[(\d+)\]")

    gpu_id = 0        # GPU referenced by the most recent matching line
    affinity = None   # most recently reported affinity value

    for line in output.splitlines():
        match = patternAffy.search(line)
        if match:
            gpu_id = int(match.group(1))
            affinity = int(match.group(2))
            _store(gpu_affinity, gpu_id, affinity)
        match = patternNode.search(line)
        if match:
            gpu_id = int(match.group(1))
            _store(numa_node, gpu_id, int(match.group(2)))
        # cpx, tpx etc. are missing affinity info; fill it in here.
        if patternErrA.search(line):
            # Prefer the GPU index printed on the error line itself, if any.
            gpu_match = patternGpu.search(line)
            if gpu_match:
                gpu_id = int(gpu_match.group(1))
            # Reuse the previous GPU's affinity. If none has been seen yet,
            # fall back to this GPU's NUMA node (as the amd-smi path does).
            fallback = affinity
            if fallback is None and gpu_id < len(numa_node):
                fallback = numa_node[gpu_id]
            _store(gpu_affinity, gpu_id, fallback)

    if len(gpu_affinity) == 0:
        gpu_affinity.append(0)

    if debug_numa > 2: print("parse_rocm_smi_toponuma:" , gpu_affinity, numa_node, hip_uuid)
    return gpu_affinity, numa_node, hip_uuid


def parse_lscpu_numa():
    """
    Map NUMA node ids to their CPU lists using 'lscpu' output.

    Parses lines of the form
        NUMA node0 CPU(s):                       0-7
        NUMA node1 CPU(s):                       8-15
    into a list where numa_cpus[node_id] is the CPU list string
    (e.g. "0-7", suitable for 'taskset -c'). Node ids with no matching
    line are left as None.

    Returns None (after printing an error) if lscpu cannot be run.
    """
    try:
        # Run lscpu in the C locale: its field labels may be translated in
        # other locales, and the regex below only matches the English form.
        result = subprocess.run(['lscpu'], capture_output=True, text=True, check=True,
                                env=dict(os.environ, LC_ALL="C"))
        output = result.stdout
    except subprocess.CalledProcessError as e:
        print(f"Error executing lscpu: {e}")
        return None
    except FileNotFoundError:
        print("Error: 'lscpu' command not found.")
        return None

    numa_cpus = []
    patternLSCPU = re.compile(r"NUMA node(\d+) CPU\(s\):\s+([\d,-]+)")

    if debug_numa > 2:print("NUMA CPUs:")
    for line in output.splitlines():
        match = patternLSCPU.search(line)
        if match:
            numa_id = int(match.group(1))
            cpus = match.group(2)
            if debug_numa > 2:print("  numa cores:", numa_id, cpus)
            # Ensure the list is large enough to be indexed by the NUMA node id
            while len(numa_cpus) <= numa_id:
                numa_cpus.append(None)
            numa_cpus[numa_id] = cpus
    return numa_cpus

def check_numactl_exists():
    """Report whether a 'numactl' executable can be launched from PATH.

    The exit status of 'numactl --version' is irrelevant: if the program
    starts at all, it exists. A missing executable gives False. Any other
    launch failure is printed and also gives False.
    """
    try:
        subprocess.run(['numactl', '--version'],
                       stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except FileNotFoundError:
        # not on PATH
        return False
    except Exception as err:
        print(f"An unexpected error occurred: {err}")
        return False
    return True

def check_taskset_exists():
    """Report whether a 'taskset' executable can be launched from PATH.

    The exit status of 'taskset --version' is irrelevant: if the program
    starts at all, it exists. A missing executable gives False. Any other
    launch failure is printed and also gives False.
    """
    try:
        subprocess.run(['taskset', '--version'],
                       stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except FileNotFoundError:
        # not on PATH
        return False
    except Exception as err:
        print(f"An unexpected error occurred: {err}")
        return False
    return True

def helpExit(exCode):
    """
    Print the gpurun usage text and terminate the process.

    exCode: process exit status. When it is 1, the line
    "Error: nothing to bind" is printed first (processArgs passes 1 when no
    program was given).

    Never returns; raises SystemExit(exCode).
    """
    if exCode == 1: print("Error: nothing to bind")
    print("Usage: gpurun [gpurun_options] Program and options")
    print("  -h --help : display help text")
    print("  -v        : display gpurun command")
    print("  -vv       : display additional debug info")
    print("  -vvv      : display more debug info")
    print("  -dryrun   : display the command but do not run it")
    print("  -taskset  : use taskset for binding")
    print("  -numactl  : use numactl for binding [default]")
    print("  -nobind   : no CPU/memory binding, only set ROCR_VISIBLE_DEVICES")
    print("  -l        : use numactl --localalloc")
    print("  -m        : use numactl --membind [default]")
    print("  -md       : Set number of desired devices for multi-device mode, default=1")
    print("  -nr       : use numactl ROCR_VISIBLE_DEVICES")
    print("  -nm       : use numactl OMPI_COMM_WORLD_LOCAL_RANK")
    print("  -topo     : display the topology and exit")
    print("  -rocmsmi  : force use of rocm-smi rather than amd-smi")
    print("  -amdsmi   : force use of amd-smi rather than rocm-smi")
    print("  -nomask   : sets GPURUN_MASK_POLICY to nomask : not yet implemented")
    print("  --version : Print version of gpurun and exit")
    print("")
    print("Supported environment variables")
    print("  GPURUN_DEVICE_BIAS    Device# to start with [default 0]")
    print("  GPURUN_BYPASS         pass through, no bindings")
    print("")

    sys.exit(exCode)
    # still to do
    #  -m   use numactl membind to CPUs in same NUMA domain. Note: Allocation
    #       fails when not enough memory available on these nodes.
    #  -l   use numactl localalloc to CPUs in same NUMA domain. Note: If
    #       memory cannot be allocated, alloc falls back to other nodes.
    #  support GPU-xxxxxxxx

def processArgs():
    """
    Parse the leading gpurun options from sys.argv.

    Options are consumed left to right until the first argument that is not
    a gpurun option. That argument and everything after it form the program
    to launch.
    -h/-help/--help prints usage and exits 0. --version prints the version
    and exits 0. A missing program exits via helpExit(1) unless -topo was
    given.

    Returns the tuple
        (sysPos, debug_numa, use_taskset, use_numactl, use_nobind, dry_run,
         use_md, md_count, use_nr, use_nm, dump_topo, use_rocmsmi,
         use_amdsmi, use_membind, use_localalloc)
    where sysPos is the index in sys.argv of the first program argument.
    """
    sysPos=1
    debug_numa=0
    use_taskset=False
    use_numactl=True
    use_nobind=False
    use_nr=False
    use_nm=False
    use_md=False
    use_localalloc=False
    use_membind=False
    use_rocmsmi=False
    use_amdsmi=False
    md_count=1
    use_nomask_policy=False  # accepted but not yet acted upon
    dump_topo=False
    dry_run=False
    skip_args = ["-s", "-q" ]
    # loop over bind arguments
    while True:
        if len(sys.argv[sysPos:]) == 0:
            if dump_topo: break
            helpExit(1)
        arg = sys.argv[sysPos]
        if   arg == "-v":   debug_numa=1
        elif arg == "-vv":  debug_numa=2
        elif arg == "-vvv": debug_numa=3
        elif arg in ["-h", "-help", "--help"]: helpExit(0)
        elif arg == "--version": print("Version: 22.0.0"); sys.exit(0)
        elif arg == "-dryrun": dry_run=True
        elif arg == "-taskset": use_taskset=True; use_numactl=False; use_nobind=False
        elif arg == "-numactl": use_numactl=True; use_taskset=False; use_nobind=False
        elif arg == "-nobind":  use_nobind=True;  use_taskset=False; use_numactl=False
        elif arg == "-topo":  dump_topo=True
        elif arg == "-nr":  use_nr=True
        elif arg == "-nm":  use_nm=True
        elif arg == "-m":  use_membind=True
        elif arg == "-l":  use_localalloc=True
        elif arg == "-nomask":  use_nomask_policy=True
        elif arg == "-rocmsmi":  use_rocmsmi=True; use_amdsmi=False
        elif arg == "-amdsmi":  use_amdsmi=True; use_rocmsmi=False
        elif arg == "-md":
            use_md=True
            # The device count is optional; -md may also be the last argument.
            if sysPos + 1 < len(sys.argv) and sys.argv[sysPos+1].isdigit():
                md_count=int(sys.argv[sysPos+1])
                sysPos += 1
        # to be implemented GPURUN options: accepted and ignored for now
        elif arg in skip_args: pass
        else: break
        sysPos += 1

    return sysPos, debug_numa, use_taskset, use_numactl, use_nobind, dry_run, use_md, md_count, use_nr, use_nm, dump_topo, use_rocmsmi, use_amdsmi, use_membind, use_localalloc

def dumpTopology(affinity_data, node_data, hip_uuid, numa_cpus):
    """
    Print the GPU/NUMA topology table and exit with status 0.

    The header shows the kernel numa_balancing setting and the transparent
    hugepage mode ("<unknown>" when the file is absent). There is one row
    per entry of node_data: GPU index, NUMA node, NUMA affinity, UUID, and
    the CPU list of the affinity node taken from numa_cpus.

    Missing entries print as None. This covers the other lists being None
    or shorter than node_data (e.g. the [0] affinity fallback).
    """
    def _read_setting(path):
        # These procfs/sysfs knobs are not present on every kernel/container.
        try:
            with open(path, 'r') as f:
                return f.read().strip()
        except OSError:
            return "<unknown>"

    def _at(seq, idx):
        # Tolerant lookup: None when seq is missing or idx is not a valid index.
        if seq is None or not isinstance(idx, int) or not 0 <= idx < len(seq):
            return None
        return seq[idx]

    numaStat = _read_setting('/proc/sys/kernel/numa_balancing')
    pageSize = _read_setting('/sys/kernel/mm/transparent_hugepage/enabled')
    Tb="\t"
    print("Topology     numa_balancing: "+numaStat+"   PageSize: "+pageSize+"\n\nGPU     Node  Affinity       UUID               Cores")
    for i, node in enumerate(node_data or []):
        affinity = _at(affinity_data, i)
        print(i, Tb, node, Tb, affinity, Tb, _at(hip_uuid, i), Tb, _at(numa_cpus, affinity))
    sys.exit(0)

if __name__ == "__main__":
    sysPos, debug_numa, use_taskset, use_numactl, use_nobind, dry_run, use_md, md_count, use_nr, use_nm, dump_topo, use_rocmsmi, use_amdsmi, use_membind, use_localalloc = processArgs()

    def _exit_with_status(completed):
        """Exit gpurun with the launched program's exit status.

        subprocess reports death by signal N as returncode -N. That is mapped
        to the shell convention 128+N so callers still see a failure.
        """
        rc = completed.returncode
        sys.exit(rc if rc >= 0 else 128 - rc)

    # support override by envvar
    gpurun_bypass = int(os.environ.get('GPURUN_BYPASS', '0'))
    my_env = os.environ.copy()
    if gpurun_bypass:
        # Pass-through: run the program as given, with no bindings.
        result = subprocess.run(sys.argv[sysPos:], env=my_env, check=False)
        _exit_with_status(result)

    # check for numactl and taskset
    has_numactl = check_numactl_exists()
    has_taskset = check_taskset_exists()

    # get topo info
    numa_cpus = None
    if use_taskset or dump_topo: numa_cpus = parse_lscpu_numa()
    if use_amdsmi and noAmdSmi:
        print("Error: -amdsmi requested but the amdsmi Python module could not be imported")
        sys.exit(1)
    if use_amdsmi:
        affinity_data, node_data, hip_uuid = get_amd_smi_static_numa()
    elif noAmdSmi or use_rocmsmi:
        affinity_data, node_data, hip_uuid = parse_rocm_smi_toponuma()
    else:
        affinity_data, node_data, hip_uuid = get_amd_smi_static_numa()

    if debug_numa > 1: print(affinity_data, node_data, hip_uuid)
    # Without per-GPU NUMA data there is nothing to select or bind to.
    if not node_data:
        print("Error: unable to determine GPU topology")
        sys.exit(1)
    if dump_topo: dumpTopology(affinity_data, node_data, hip_uuid, numa_cpus)

    numGpus = len(node_data)
    # ROCR_VISIBLE_DEVICES may be a list such as "0,1"; select by its first entry.
    rvd_first = os.environ.get('ROCR_VISIBLE_DEVICES', '-1').split(',')[0].strip()
    try:
        rocrVisDev = int(rvd_first)
    except ValueError:
        print("Warning: ignoring non-numeric ROCR_VISIBLE_DEVICES entry", rvd_first)
        rocrVisDev = -1
    localRank = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', '0'))
    numRanksLocal = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_SIZE', '1'))
    gpurun_device_bias = int(os.environ.get('GPURUN_DEVICE_BIAS', '0'))

    # Every branch is reduced modulo numGpus so adjRank is always a valid
    # index into node_data (e.g. -nr without ROCR_VISIBLE_DEVICES gives -1).
    if rocrVisDev != -1 or use_nr:
        adjRank = (rocrVisDev + gpurun_device_bias) % numGpus
    elif use_nm or use_numactl:
        adjRank = (localRank + gpurun_device_bias) % numGpus
    else:
        adjRank = gpurun_device_bias % numGpus
    if debug_numa > 1:
        print("#GPUs ", numGpus, "numRanks", numRanksLocal, "localRank", localRank, "adjRank", adjRank, "RVD", rocrVisDev, "gpurun_device_bias", gpurun_device_bias)
        if debug_numa > 2:
            if affinity_data is not None and node_data is not None:
                print("\nGPU Affinity:")
                for i, affinity in enumerate(affinity_data):
                    if affinity is not None:
                        print(f"  GPU {i}: Affinity = {affinity}")

                print("\n GPU NUMA Nodes:")
                for i, node in enumerate(node_data):
                    if node is not None:
                        print(f"  GPU {i}: NUMA Node = {node}")

    if use_md:
        # -md N: expose N consecutive GPUs starting at adjRank, wrapping
        # around, with N clamped to 1..numGpus.
        dev_count = max(1, min(md_count, numGpus))
        my_env["ROCR_VISIBLE_DEVICES"] = ",".join(
            str((adjRank + k) % numGpus) for k in range(dev_count))
    else:
        my_env["ROCR_VISIBLE_DEVICES"] = str(adjRank)

    # CPUs are bound to the GPU's NUMA node. Memory is bound to its NUMA
    # affinity, falling back to the NUMA node when no affinity entry exists.
    cpu_node = node_data[adjRank]
    mem_node = cpu_node
    if affinity_data is not None and adjRank < len(affinity_data) and affinity_data[adjRank] is not None:
        mem_node = affinity_data[adjRank]

    if use_taskset and has_taskset:
        if use_localalloc or use_membind: print("Warning: taskset does not support localalloc or membind, use numactl")
        if not numa_cpus or cpu_node is None or not 0 <= cpu_node < len(numa_cpus) or numa_cpus[cpu_node] is None:
            print("Error: no CPU list found for NUMA node", cpu_node)
            sys.exit(1)
        program_to_run = [ "taskset", "-c", numa_cpus[cpu_node] ]
    elif use_numactl and has_numactl:
        program_to_run = [ "numactl", "--cpunodebind", str(cpu_node) ]
        # numactl accepts a single memory policy, so -l replaces --membind
        # instead of being added alongside it.
        if use_localalloc:
            program_to_run.append("--localalloc")
        else:
            program_to_run += [ "--membind", str(mem_node) ]
    else:
        # -nobind, or the requested binding tool is not installed.
        program_to_run = [ ]
    program_to_run.extend(sys.argv[sysPos:])
    if debug_numa > 0 or dry_run: print("ROCR_VISIBLE_DEVICES", my_env["ROCR_VISIBLE_DEVICES"], " ", program_to_run)
    if not dry_run:
        result = subprocess.run(program_to_run, env=my_env, check=False)
        _exit_with_status(result)
