import { RawGreeksDataMap, GREEK_IDX } from "types";
import { VolatilityData, VolatilityDataRow, OptionLegType, ModelParameters, ProfitLossWithDomain, ProfitLossData, OptionLeg } from "types/optionsStrat";
import BlackScholes from "./BlackScholes";

/**
 * Lists the strike prices recorded under the first timestamp key of a greeks object.
 *
 * "First" means the first key reported by `Object.keys`, so integer-like timestamp keys
 * are visited in ascending numeric order.
 *
 * @param currentGreeksObj - Object keyed by timestamp, whose values are objects keyed by strike.
 * @returns The strikes (converted to numbers) of the first timestamp, or `undefined` when the
 *   input is falsy or has no keys.
 */
export function getStrikePricesForFirstTimestamp(
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- shape not typed by callers yet
  currentGreeksObj: any,
): number[] | undefined {
  if (!currentGreeksObj) return undefined;

  const [leadingTimestamp] = Object.keys(currentGreeksObj);
  if (leadingTimestamp === undefined) return undefined;

  const strikesAtTimestamp = currentGreeksObj[leadingTimestamp];
  return Object.keys(strikesAtTimestamp).map((strikeKey) => Number(strikeKey));
}

/**
 * Extracts a per-expiry volatility smile from raw greeks data.
 *
 * Each top-level key of `data` is parsed as the snapshot's `tte`; each nested key is parsed
 * as a strike price, whose volatility is read from the first greeks row at index
 * `GREEK_IDX.VOLATILITY`. Strikes that carry no greeks rows are skipped rather than crashing.
 *
 * Only own enumerable keys are visited (unlike `for…in`, inherited keys are ignored).
 *
 * NOTE(review): the output order follows object key order. `interpolateVolatilityTable`
 * indexes the result by day, which assumes the first entry is 0 DTE (e.g. after
 * getFilteredGreeksData); confirm against the caller.
 *
 * @param data - Raw greeks keyed by time-to-expiry, then by strike.
 * @returns One `VolatilityData` entry per time-to-expiry key.
 */
export function extractVolatilityData(data: RawGreeksDataMap): VolatilityData[] {
  return Object.entries(data).map(([tte, priceLevel]) => {
    const strikePrices: VolatilityData["strikePrices"] = {};

    for (const [price, greekRows] of Object.entries(priceLevel)) {
      const firstRow = greekRows[0];
      // An empty greeks list would otherwise throw when indexing `undefined`.
      if (firstRow === undefined) {
        continue;
      }
      strikePrices[Number(price)] = firstRow[GREEK_IDX.VOLATILITY];
    }

    return { tte: Number(tte), strikePrices };
  });
}

/**
 * Linearly interpolates volatility across strikes for each requested price.
 *
 * Prices below the lowest strike take the lowest strike's volatility, and prices at or above
 * the highest strike take the highest strike's volatility. Between two strikes the
 * volatility is interpolated linearly.
 *
 * Neither input needs to be sorted:
 * - `data` is sorted by strike on a copy, so the caller's array is never mutated.
 * - Each price is located by binary search, so `prices` may be in any order.
 *
 * @param data - Strike/volatility samples for a single expiry (any order).
 * @param prices - Prices to evaluate at.
 * @returns One volatility per entry of `prices`, in the same order; all zeros when `data`
 *   is empty.
 */
export function interpolateVolatilityRow(data: VolatilityDataRow[], prices: number[]): number[] {
  if (data.length === 0) {
    // Mirrors interpolateVolatilityTable's behaviour for an empty surface instead of
    // throwing on data[-1].
    return prices.map(() => 0);
  }

  // Rows built from Object.keys list integer strikes first and fractional strikes
  // (e.g. "102.5") afterwards, so the input cannot be assumed to be ordered.
  const rows = [...data].sort((a, b) => a.strike - b.strike);
  const last = rows.length - 1;

  return prices.map((price) => {
    const idx = lastRowAtOrBelow(rows, price);
    if (idx < 0) {
      // clamp left
      return rows[0].volatility;
    }
    if (idx >= last) {
      // clamp right
      return rows[last].volatility;
    }
    // interpolate; rows[idx].strike <= price < rows[idx + 1].strike, so x1 > x0
    const { strike: x0, volatility: y0 } = rows[idx];
    const { strike: x1, volatility: y1 } = rows[idx + 1];
    return y0 + (y1 - y0) * (price - x0) / (x1 - x0);
  });
}

/** Index of the last row whose strike is <= `price` (rows ascending by strike), or -1 if none. */
function lastRowAtOrBelow(rows: VolatilityDataRow[], price: number): number {
  let lo = 0;
  let hi = rows.length;
  while (lo < hi) {
    const mid = (lo + hi) >>> 1;
    if (rows[mid].strike <= price) {
      lo = mid + 1;
    } else {
      hi = mid;
    }
  }
  return lo - 1;
}

/**
 * Converts one snapshot's strike→volatility map into rows sorted by ascending strike.
 *
 * `Object.keys` lists integer-like keys first and other keys (such as fractional strikes)
 * in insertion order, so the key order alone is not guaranteed to be ascending.
 */
function toVolatilityRow(snapshot: VolatilityData): VolatilityDataRow[] {
  return Object.keys(snapshot.strikePrices)
    .map((strike) => ({
      strike: Number(strike),
      volatility: snapshot.strikePrices[Number(strike)],
    }))
    .sort((a, b) => a.strike - b.strike);
}

/**
 * Interpolates volatility for each price at a (possibly fractional) day offset `dte`.
 *
 * `data` is indexed by whole day: entry `k` is treated as the snapshot for day `k`.
 * - For `dte` between two snapshots, the two strike-interpolated rows are blended linearly
 *   by the fractional part of `dte`.
 * - Offsets before the first snapshot clamp to `data[0]`.
 * - Offsets at or past the last snapshot clamp to the last entry.
 *
 * @param data - Volatility snapshots, one per day offset.
 * @param prices - Prices to evaluate at.
 * @param dte - Day offset into `data`.
 * @returns One volatility per price; all zeros when `data` is empty.
 */
export function interpolateVolatilityTable(data: VolatilityData[], prices: number[], dte: number): number[] {
  if (data.length === 0) {
    return prices.map(() => 0);
  }

  const r0 = Math.floor(dte);
  const r1 = r0 + 1;

  if (r0 < 0) {
    return interpolateVolatilityRow(toVolatilityRow(data[0]), prices);
  }
  if (r1 >= data.length) {
    return interpolateVolatilityRow(toVolatilityRow(data[data.length - 1]), prices);
  }

  // Weight of the later snapshot grows with the fractional part of dte.
  const w1 = dte - r0;
  const w0 = 1 - w1;

  const vol0Interpolated = interpolateVolatilityRow(toVolatilityRow(data[r0]), prices);
  const vol1Interpolated = interpolateVolatilityRow(toVolatilityRow(data[r1]), prices);

  return vol0Interpolated.map((v, i) => v * w0 + vol1Interpolated[i] * w1);
}

/**
 * Generates evenly spaced prices from `minValue` to `maxValue` inclusive.
 *
 * The number of points is `numSteps`, capped at `ceil((maxValue - minValue) / minStepSize)`
 * so that neighbouring prices are never closer together than about `minStepSize`.
 *
 * @param minValue - First price.
 * @param maxValue - Last price; the final element is exactly this value.
 * @param numSteps - Requested number of points (default 1000).
 * @param minStepSize - Minimum spacing used to cap the point count (default 0.01).
 * @returns The generated prices:
 *   - `[minValue]` when only one point fits;
 *   - `[]` when the range is empty or negative, or when no points are requested.
 */
export function generatePrices(
  minValue: number,
  maxValue: number,
  numSteps: number = 1_000,
  minStepSize: number = 0.01,
): number[] {
  const range = maxValue - minValue;
  // This is an upper bound on the point count, not a minimum: more points would make the
  // spacing smaller than minStepSize.
  const maxStepsForResolution = Math.ceil(range / minStepSize);
  const count = Math.floor(Math.min(numSteps, maxStepsForResolution));

  // Covers 0, negative counts (maxValue < minValue) and NaN.
  if (!(count >= 1)) {
    return [];
  }
  // With one point, range / (count - 1) would be range / 0 = Infinity, and 0 * Infinity
  // produces NaN.
  if (count === 1) {
    return [minValue];
  }

  const stepSize = range / (count - 1);
  // Pin the last point so float accumulation cannot overshoot or undershoot maxValue.
  return Array.from({ length: count }, (_, i) => (i === count - 1 ? maxValue : minValue + i * stepSize));
}

/** True for call legs (long or short); false for every other leg type. */
export function isCall(type: OptionLegType): boolean {
  return type === OptionLegType.LONG_CALL || type === OptionLegType.SHORT_CALL;
}

/** True for long legs (long call or long put); false for every other leg type. */
export function isLong(type: OptionLegType): boolean {
  return type === OptionLegType.LONG_CALL || type === OptionLegType.LONG_PUT;
}

/**
 * Profit or loss of a single option leg if the underlying trades at `currentStockPrice`.
 *
 * How the leg is valued:
 * - While `modelParameters.daysToExpiration > 0`, it is priced with Black-Scholes, using
 *   `modelParameters.volatility` when set and `statisticalVolatility` otherwise.
 * - Otherwise it is valued at intrinsic value.
 *
 * That value is multiplied by `shares` and compared with `premium`: long legs gain when the
 * value exceeds the premium, short legs when it falls below.
 *
 * @param leg - Option leg; reads `legType`, `strikePrice`, `shares` and `premium`.
 * @param currentStockPrice - Hypothetical underlying price.
 * @param modelParameters - Reads `daysToExpiration`, `riskFreeRate` and optional `volatility`.
 * @param statisticalVolatility - Fallback volatility when the model has none set.
 * @returns Profit (positive) or loss (negative) for the leg.
 */
export function calculateOptionLegProfit(
  { legType, strikePrice, shares, premium }: OptionLeg,
  currentStockPrice: number,
  modelParameters: ModelParameters,
  statisticalVolatility: number
): number {
  const call = isCall(legType);
  const { daysToExpiration, riskFreeRate, volatility } = modelParameters;

  let price: number;
  if (daysToExpiration > 0) {
    // Only build the pricing model when it is used. At or past expiration yte would be
    // <= 0, where the model's result is never consulted anyway.
    price = new BlackScholes({
      isCall: call,
      currentPrice: currentStockPrice,
      strike: strikePrice,
      riskFreeRate,
      yte: daysToExpiration / 365.0,
      iVol: volatility ?? statisticalVolatility,
    }).price();
  } else {
    // Intrinsic value.
    price = call
      ? Math.max(0, currentStockPrice - strikePrice)
      : Math.max(0, strikePrice - currentStockPrice);
  }

  return isLong(legType) ? price * shares - premium : premium - price * shares;
}

/**
 * Chart x-domain for a P/L curve.
 *
 * When the curve changes sign at least twice, the domain is computed as follows:
 * 1. Take the span from the price just before the first sign change to the price just after
 *    the last one.
 * 2. Pad each side by 15% of that span, but by at least 1 price unit.
 * 3. Round outward to whole numbers and clamp the lower bound at 0.
 *
 * Otherwise the domain is the whole generated price range, with the lower bound clamped
 * at 0.
 */
function getProfitDomain(profitLossData: ProfitLossData[]): [number, number] {
  let minDomain: number | null = null;
  let maxDomain: number | null = null;

  for (let i = 0; i < profitLossData.length - 1; i++) {
    const current = profitLossData[i].profit;
    const next = profitLossData[i + 1].profit;
    if ((current >= 0 && next < 0) || (current < 0 && next >= 0)) {
      if (minDomain === null) {
        minDomain = profitLossData[i].price;
      } else {
        maxDomain = profitLossData[i + 1].price;
      }
    }
  }

  if (minDomain === null || maxDomain === null) {
    if (profitLossData.length === 0) {
      return [0, 0];
    }
    // Use the first generated price rather than a hard-coded 0, so a caller-supplied
    // priceDomain starting above 0 is respected.
    return [
      Math.max(0, profitLossData[0].price),
      profitLossData[profitLossData.length - 1].price,
    ];
  }

  const padding = Math.max(1, (maxDomain - minDomain) * 0.15); // MINIMUM_CHART_X_DOMAIN placeholder
  const lower = Math.floor(minDomain - padding);
  const upper = Math.ceil(maxDomain + padding);

  return [Math.max(0, lower), upper];
}

/**
 * Computes the combined profit/loss curve of a set of option legs over a price range.
 *
 * Prices come from `priceDomain` when given, otherwise from 0 to twice the highest strike.
 * Each price is paired with a volatility interpolated from `volData` at
 * `modelParameters.daysToExpiration`. The curve is then trimmed to the domain chosen by
 * `getProfitDomain`.
 *
 * @param optionLegs - Legs to sum; an empty list yields no data and domain `[0, 0]`.
 * @param volData - Volatility snapshots for interpolation.
 * @param modelParameters - Pricing parameters shared by every leg.
 * @param priceDomain - Optional `[min, max]` price range to evaluate.
 * @returns The trimmed P/L points and the chart domain.
 */
export function calculateProfitLoss(optionLegs: OptionLeg[], volData: VolatilityData[], modelParameters: ModelParameters, priceDomain: number[] | null): ProfitLossWithDomain {
  if (optionLegs.length === 0) {
    return {
      profitLossData: [],
      domain: [0, 0],
    };
  }

  const prices = priceDomain
    ? generatePrices(priceDomain[0], priceDomain[1])
    : generatePrices(0, 2 * Math.max(...optionLegs.map((option) => option.strikePrice)));
  const priceVol = interpolateVolatilityTable(volData, prices, modelParameters.daysToExpiration);

  const allProfitLossData: ProfitLossData[] = prices.map((price, index) => {
    const profit = optionLegs.reduce(
      (sum, optionLeg) => sum + calculateOptionLegProfit(optionLeg, price, modelParameters, priceVol[index]),
      0,
    );

    // For a loss, profit * 1.2 is the more negative value, so order the pair explicitly to
    // keep profitBounds as [lower, upper].
    const scaledLow = profit * 0.8;
    const scaledHigh = profit * 1.2;

    return {
      price,
      profit,
      volatility: modelParameters.volatility ?? priceVol[index],
      profitBounds: [Math.min(scaledLow, scaledHigh), Math.max(scaledLow, scaledHigh)],
    };
  });

  const [minPriceDomain, maxPriceDomain] = getProfitDomain(allProfitLossData);
  const chartProfitLossData = allProfitLossData.filter(
    (data) => data.price >= minPriceDomain && data.price <= maxPriceDomain,
  );

  return {
    profitLossData: chartProfitLossData,
    domain: [minPriceDomain, maxPriceDomain],
  };
}