from datetime import datetime, timedelta
from getpass import getuser
import functools
import json
import os

import click

from ... import colors as c
from .hk_base import HKSpiBase
from .hk_ft232h import HKSpiFT232H
from .hk_stm32 import HKSpiSTM32
from .hk_pico import HKSpiPico

# Every supported driver class (not instance). When no board type is selected,
# HKSpi() asks each of these, in this order, to list its devices.
ALL_DRIVERS: 'list[type[HKSpiBase]]' = [
  HKSpiSTM32,
  HKSpiPico,
  HKSpiFT232H,
]

# JSON file recording which user has claimed which device. It lives in /tmp so
# that all users on this machine read and write the same file.
CLAIM_PATH = '/tmp/chipforge_device_claims'

def pass_device_with_options(fn):
  """Decorate a command callback so it is called with a claimed device as its first argument.

  Adds the ``--serial``, ``--path``, ``--fpga``, ``--pmod``, ``--fab`` and hidden
  ``--override`` click options (each also readable from a ``CHIPFORGE_DEVICE_*``
  environment variable). The board-type flags restrict which drivers are searched,
  and at most one of them may be given. The device is obtained from ``HKSpi`` and
  ``fn(hk, *args, **kwargs)`` is called inside a ``with`` block on it.

  ``functools.wraps`` copies ``fn``'s name, qualname, module, docstring and
  attribute dict. Click options already attached to ``fn`` are therefore kept
  alongside the ones added here, instead of being silently dropped.

  Raises:
    click.BadOptionUsage: if more than one of --fpga/--pmod/--fab is given.
  """
  @click.option('--serial', envvar='CHIPFORGE_DEVICE_SERIAL', metavar='SERIAL_NUMBER', help='The serial number of the device to use')
  @click.option('--path', envvar='CHIPFORGE_DEVICE_PATH', metavar="/dev/* or ftdi://", help='The path or URL of the device to use')
  @click.option('--fpga', envvar='CHIPFORGE_DEVICE_FPGA', is_flag=True, help='Select the first available FPGA device')
  @click.option('--pmod', envvar='CHIPFORGE_DEVICE_PMOD', is_flag=True, help='Select the first available fabricated device with PMOD connectors (green)')
  @click.option('--fab', envvar='CHIPFORGE_DEVICE_FAB', is_flag=True, help='Select the first available fabricated device (black or green)')
  @click.option('--override', envvar='CHIPFORGE_DEVICE_OVERRIDE', hidden=True, is_flag=True, help='Override any existing claims to the device')
  # Must be the innermost decorator: the click options above attach themselves to
  # the already-wrapped function.
  @functools.wraps(fn)
  def wrapped_fn(serial, path, fpga: bool, pmod: bool, fab: bool, override: bool, *args, **kwargs):
    # Booleans sum as 0/1, so this counts how many board-type flags were given.
    if fpga + pmod + fab > 1:
      raise click.BadOptionUsage('fpga', 'Only one of [--fpga, --pmod, or --fab] are allowed')

    if fpga:
      drivers = [HKSpiSTM32]
      click.echo(c.info('> Searching for FPGA devices (white board)'))
    elif pmod:
      drivers = [HKSpiPico]
      click.echo(c.info('> Searching for PMOD devices (green board)'))
    elif fab:
      drivers = [HKSpiPico, HKSpiFT232H]
      click.echo(c.info('> Searching for all fabricated devices (green & black board)'))
    else:
      drivers = ALL_DRIVERS

    with HKSpi(drivers=drivers, serial=serial, path=path, override=override) as hk:
      return fn(hk, *args, **kwargs)

  return wrapped_fn

class Claims:
  """File-backed record of which user has claimed which device.

  Claims are stored as a JSON list in the file at ``path`` (``CLAIM_PATH`` by
  default). ``write_back`` tries to make that file world-writable so other users
  can update it. A claim expires five minutes after it was last (re)claimed.

  Use as a context manager. Entering loads the file and drops expired claims.
  Every mutating method writes the file back immediately.
  """

  # strftime/strptime format of the 'expires' field (naive local time, as
  # produced by datetime.now()).
  _TIME_FORMAT = '%Y-%m-%dT%H:%M:%S'

  class Claim:
    """A single claim: the device ``path``, the owning user name, and its expiry time."""

    def __init__(self, path: str, owner: str, expires: datetime):
      self.path = path
      self.owner = owner
      self.expires = expires

  def __init__(self, path: str=CLAIM_PATH):
    """Create a store backed by the JSON file at ``path``. Nothing is read until entered."""
    self.path = path
    self.claims: 'list[Claims.Claim]' = []
    # True once this store has written the file back during the current session.
    self.modified = False

  def __enter__(self):
    """Load claims from disk and drop expired ones (rewriting the file if any were dropped)."""
    self.modified = False
    self.claims = self._load()
    now = datetime.now()
    # Build a new list rather than removing while iterating. Removing in place
    # skips the element that follows each removed one.
    live_claims = [claim for claim in self.claims if claim.expires >= now]
    if len(live_claims) != len(self.claims):
      self.claims = live_claims
      self.write_back()
    return self

  def _load(self) -> 'list[Claims.Claim]':
    """Parse the claims file. A missing, unparsable, or malformed file yields no claims."""
    try:
      with open(self.path) as file:
        raw_claims = json.load(file)
    except (FileNotFoundError, json.JSONDecodeError):
      return []
    try:
      return [
        self.Claim(raw['path'], raw['owner'], datetime.strptime(raw['expires'], self._TIME_FORMAT))
        for raw in raw_claims
      ]
    except (KeyError, TypeError, ValueError):
      # Valid JSON with the wrong shape (missing keys, bad timestamp, not a list
      # of objects). Treat it as corrupt, the same as invalid JSON.
      return []

  def write_back(self):
    """Overwrite the claims file with the current in-memory claims."""
    raw_claims = [
      {'path': claim.path, 'owner': claim.owner, 'expires': claim.expires.strftime(self._TIME_FORMAT)}
      for claim in self.claims
    ]
    with open(self.path, 'w+') as file:
      try:
        # Open the shared file up to other users. Only the file's owner may chmod
        # it, so this is expected to fail when another user created the file.
        os.chmod(self.path, 0o777)
      except OSError:
        pass
      json.dump(raw_claims, file)
    self.modified = True

  def _find(self, device: HKSpiBase) -> 'Claims.Claim | None':
    """Return the first claim on ``device``'s path, or None when it is unclaimed."""
    return next((claim for claim in self.claims if claim.path == device.device), None)

  def claim(self, device: HKSpiBase, fail_soft=False, override=False):
    """Claim ``device`` for the current user for the next five minutes.

    If the current user already holds the claim, it is renewed. A claim held by
    another user is taken over when ``override`` is set. Otherwise the method
    returns False if ``fail_soft`` is set, or prints an error and exits with
    status 1.

    Returns:
      True when the device is now claimed by the current user, otherwise False.
    """
    me = getuser()
    expires = datetime.now() + timedelta(minutes=5)
    existing = self._find(device)

    if existing is None:
      self.claims.append(self.Claim(device.device, me, expires))
    elif existing.owner == me or override:
      # Overriding transfers ownership. Later commands by this user then see the
      # device as their own instead of as claimed by someone else.
      existing.owner = me
      existing.expires = expires
    elif fail_soft:
      return False
    else:
      click.echo(c.error())
      click.echo(c.error(f'Device {c.code(device)} is already claimed by {c.code(existing.owner)}'))
      click.echo(c.error())
      raise SystemExit(1)

    self.write_back()
    return True

  def release_mine(self):
    """Drop every claim owned by the current user, rewriting the file if any were dropped."""
    me = getuser()
    remaining = [claim for claim in self.claims if claim.owner != me]
    if len(remaining) != len(self.claims):
      self.claims = remaining
      self.write_back()

  def release_all(self):
    """Drop every claim and rewrite the file."""
    self.claims = []
    self.write_back()

  def is_mine(self, device: HKSpiBase):
    """Return True if any claim on ``device``'s path belongs to the current user."""
    me = getuser()
    return any(claim.path == device.device and claim.owner == me for claim in self.claims)

  def is_unclaimed(self, device: HKSpiBase):
    """Return True if no claim exists for ``device``'s path."""
    return self._find(device) is None

  def get_owner_string(self, device: HKSpiBase):
    """Return a coloured status label: who is using ``device``, or that it is available."""
    existing = self._find(device)
    if existing is not None:
      return c.red(f'(In use by {c.code(existing.owner)})')
    return c.success('(available)')

  def __exit__(self, *_):
    """Release nothing: each change is already written back when it is made."""

def _report_no_match(claims: 'Claims', devices: 'list[HKSpiBase]', serial, path):
  """Print why no device was selected, then list every device found with its claim status."""
  click.echo(c.error())
  if serial is not None and path is not None:
    click.echo(c.error(f'No device matching serial number {c.code(serial)} and device path {c.code(path)} found'))
  elif serial is not None:
    click.echo(c.error(f'No device matching serial number {c.code(serial)} found'))
  elif path is not None:
    click.echo(c.error(f'No device matching device path {c.code(path)} found'))
  else:
    # No filter was given. Without this line the user sees only the device
    # listing below, with no stated reason.
    click.echo(c.error('No available devices found'))

  click.echo(c.error('All devices:'))
  for device in devices:
    click.echo(c.error(c.default(f'  {device} {claims.get_owner_string(device)}')))
  click.echo(c.error())

def HKSpi(drivers: 'list[type[HKSpiBase]]'=ALL_DRIVERS, serial=None, path=None, override=False) -> HKSpiBase:
  """Find, claim and return a device.

  Devices are listed from every driver class in ``drivers`` and filtered with
  ``device.matches(serial=serial, device=path)``. If the current user already
  claims a matching device, that device is used without prompting. Otherwise the
  candidates are the matching devices that are unclaimed, or all matching devices
  when ``override`` is set. With several candidates, the user is prompted to pick
  one. The chosen device is claimed before it is returned.

  Exits the process with status 1 when no device is found or none is usable.
  """
  devices = [device for driver in drivers for device in driver.list_devices()]

  with Claims() as claims:
    if not devices:
      click.echo(c.error())
      click.echo(c.error('No devices found'))
      click.echo(c.error())
      raise SystemExit(1)

    matched_devices = []
    for device in devices:
      if not device.matches(serial=serial, device=path):
        continue
      if claims.is_mine(device):
        # A device this user already holds wins outright, so no prompt is needed.
        click.echo(c.info(f'> Found existing claim for {c.code(device.device)}'))
        matched_devices = [device]
        break
      if override or claims.is_unclaimed(device):
        matched_devices.append(device)

    if not matched_devices:
      _report_no_match(claims, devices, serial, path)
      raise SystemExit(1)

    if len(matched_devices) == 1:
      index = 0
    else:
      click.echo('\n'.join(f'  {i}) {d} {claims.get_owner_string(d)}' for i, d in enumerate(matched_devices)))
      index = click.prompt(c.info('> Which device should be used? (#)'), prompt_suffix='', type=click.IntRange(0, len(matched_devices)-1))

    chosen = matched_devices[index]
    claims.claim(chosen, override=override)
    click.echo(c.success(f'>> Using device {c.default(chosen)}\n'))
    return chosen