krfricke / arima

ARIMA modelling for Rust
Apache License 2.0
33 stars 6 forks source link

Example feeding it OHLC candles? #3

Closed brandonros closed 1 year ago

brandonros commented 1 year ago
use chrono::{DateTime, Datelike, TimeZone, Weekday};
use chrono_tz::{Tz, US};
use serde::Deserialize;

/// Direction of a candle's body as written to the `color` output column.
#[derive(Debug)]
enum Color {
  /// The candle closed at or above its open.
  Green,
  /// The candle closed below its open.
  Red,
}

/// US equity trading session that a timestamp falls into (see `determine_session_type`).
#[derive(Debug, PartialEq)]
enum MarketSessionType {
  /// Not in any session: overnight, weekend, or a listed holiday.
  None,
  /// Pre-market, from 4:00am Eastern until the 9:30am open.
  Pre,
  /// Regular trading hours, 9:30am to 4:00pm Eastern.
  Regular,
  /// After-hours, from 4:00pm until 8:00pm Eastern.
  Post,
}

/// One OHLCV bar, deserialized by serde from a row of the candles CSV.
#[derive(Deserialize)]
struct Candle {
  /// Bar start time as a Unix timestamp in whole seconds.
  pub start_timestamp: i64,
  /// Bar end time; presumably also Unix seconds (unused by `main`) — TODO confirm.
  pub end_timestamp: i64,
  /// Opening price.
  pub open: f64,
  /// Highest price within the bar.
  pub high: f64,
  /// Lowest price within the bar.
  pub low: f64,
  /// Closing price.
  pub close: f64,
  /// Traded volume for the bar.
  pub volume: i64,
}

/// Converts a Unix timestamp in whole seconds into US/Eastern wall-clock time.
///
/// # Panics
/// Panics if `timestamp` is outside the range chrono's `NaiveDateTime` can represent.
fn datetime_from_timestamp(timestamp: i64) -> DateTime<Tz> {
  // Treat the seconds as a UTC instant, then view that instant in Eastern time.
  chrono::NaiveDateTime::from_timestamp_opt(timestamp, 0)
    .map(|utc| US::Eastern.from_utc_datetime(&utc))
    .unwrap()
}

/// Returns the regular-session bounds for the Eastern calendar day containing `timestamp`
/// (Unix seconds): 9:30:00am as the start and 3:59:59pm (the last whole second before the
/// 4:00pm close) as the end, both in US/Eastern.
///
/// # Panics
/// Panics if `timestamp` is out of range, or if either bound does not map to exactly one
/// Eastern instant.
fn get_regular_market_session_start_and_end(timestamp: i64) -> (DateTime<Tz>, DateTime<Tz>) {
  let local = datetime_from_timestamp(timestamp);
  let (y, m, d) = (local.year(), local.month(), local.day());
  // Build a wall-clock instant on the same Eastern calendar day.
  let on_this_day = |hour: u32, minute: u32, second: u32| {
    US::Eastern.with_ymd_and_hms(y, m, d, hour, minute, second).unwrap()
  };
  (on_this_day(9, 30, 0), on_this_day(15, 59, 59))
}

  fn determine_session_type(timestamp: i64) -> MarketSessionType {
    let eastern_now = datetime_from_timestamp(timestamp);
    // short circuit on weekends
    let weekday = eastern_now.weekday();
    let is_weekend = weekday == Weekday::Sat || weekday == Weekday::Sun;
    if is_weekend {
      return MarketSessionType::None;
    }
    // short circuit on holidays
    let holidays_2022 = vec![
      "2022-01-17 00:00:00", // martin luther king jr day
      "2022-02-21 00:00:00", // preisdent's day
      "2022-04-15 00:00:00", // good friday
      "2022-05-30 00:00:00", // memorial day
      "2022-06-20 00:00:00", // juneteenth
      "2022-07-04 00:00:00", // independence day
      "2022-09-05 00:00:00", // labor day
      "2022-11-24 00:00:00", // day before thanksgiving
      "2022-11-25 00:00:00", // day after thanksgiving (closes at 1pm)?
      "2022-12-26 00:00:00", // day after christmas
    ];
    let holidays_2023 = vec![
      "2023-01-02 00:00:00", // new year's day
      "2023-01-16 00:00:00", // martin luther king jr day
      "2023-02-20 00:00:00", // preisdent's day
      "2023-04-07 00:00:00", // good friday
      "2023-05-29 00:00:00", // memorial day
      "2023-06-19 00:00:00", // juneteenth
      "2023-07-04 00:00:00", // independence day
      "2023-09-04 00:00:00", // labor day
      "2023-11-23 00:00:00", // thanksgiving day
      "2023-11-24 00:00:00", // day after thanksgiving (closes at 1pm)?
      "2023-12-25 00:00:00", // christmas
    ];
    let formatted_eastern_now = eastern_now.format("%Y-%m-%d 00:00:00").to_string();
    let is_2022_holiday = holidays_2022.iter().any(|&holiday| holiday == formatted_eastern_now);
    let is_2023_holiday = holidays_2023.iter().any(|&holiday| holiday == formatted_eastern_now);
    let is_holiday = is_2022_holiday || is_2023_holiday;
    if is_holiday {
      return MarketSessionType::None;
    }
    // check pre/regular/post
    let year = eastern_now.year();
    let month = eastern_now.month();
    let day = eastern_now.day();
    // premarket: 4am -> 9:29:59am
    let pre_market_start = US::Eastern.with_ymd_and_hms(year, month, day, 4, 0, 0).unwrap();
    let pre_market_end = US::Eastern.with_ymd_and_hms(year, month, day, 9, 29, 59).unwrap();
    let seconds_before_pre_market = eastern_now.signed_duration_since(pre_market_start).num_seconds();
    let seconds_after_pre_market = eastern_now.signed_duration_since(pre_market_end).num_seconds();
    let is_before_pre_market = seconds_before_pre_market < 0;
    let is_after_pre_market = seconds_after_pre_market >= 0;
    let is_during_pre_market = is_before_pre_market == false && is_after_pre_market == false;
    // regular: 9:30am -> 3:59:59pm
    let regular_market_start = US::Eastern.with_ymd_and_hms(year, month, day, 9, 30, 0).unwrap();
    let regular_market_end = US::Eastern.with_ymd_and_hms(year, month, day, 15, 59, 59).unwrap();
    let seconds_before_regular_market = eastern_now.signed_duration_since(regular_market_start).num_seconds();
    let seconds_after_regular_market = eastern_now.signed_duration_since(regular_market_end).num_seconds();
    let is_before_regular_market = seconds_before_regular_market < 0;
    let is_after_regular_market = seconds_after_regular_market >= 0;
    let is_during_regular_market = is_before_regular_market == false && is_after_regular_market == false;
    // aftermarket: 4:00pm -> 7:59:59pm
    let after_market_start = US::Eastern.with_ymd_and_hms(year, month, day, 16, 0, 0).unwrap();
    let after_market_end = US::Eastern.with_ymd_and_hms(year, month, day, 19, 59, 59).unwrap();
    let seconds_before_after_market = eastern_now.signed_duration_since(after_market_start).num_seconds();
    let seconds_after_after_market = eastern_now.signed_duration_since(after_market_end).num_seconds();
    let is_before_after_market = seconds_before_after_market < 0;
    let is_after_after_market = seconds_after_after_market >= 0;
    let is_during_after_market = is_before_after_market == false && is_after_after_market == false;
    if is_during_pre_market {
      return MarketSessionType::Pre;
    } else if is_during_regular_market {
      return MarketSessionType::Regular;
    } else if is_during_after_market {
      return MarketSessionType::Post;
    } else {
      return MarketSessionType::None;
    }
  }

/// Reads every record of the CSV file at `filename` into a `Vec<T>`.
///
/// The first row is treated as a header row. Each remaining record is deserialized into `T`
/// with serde.
///
/// # Panics
/// Panics with a message naming the file if it cannot be opened. Panics with the file name
/// and the 1-based record number (not counting the header) if any record fails to
/// deserialize.
fn read_records_from_csv<T>(filename: &str) -> Vec<T>
where
  T: for<'de> Deserialize<'de>,
{
  let file = std::fs::File::open(filename)
    .unwrap_or_else(|err| panic!("failed to open {filename}: {err}"));
  let mut csv_reader = csv::ReaderBuilder::new().has_headers(true).from_reader(file);
  csv_reader
    .deserialize::<T>()
    .enumerate()
    .map(|(index, result)| {
      // Surface which file and which record broke instead of a bare `unwrap()` panic.
      result.unwrap_or_else(|err| {
        panic!("failed to parse record {} of {filename}: {err}", index + 1)
      })
    })
    .collect()
}

fn main() {
  let resolution = 1;
  let candles_filename = format!("./candles-{resolution}.csv");
  let candles = read_records_from_csv::<Candle>(&candles_filename);
  let mut previous_close = 0.0;
  println!("grouping_key,start_timestamp,session_type,open,high,low,close,volume,previous_close,close_difference_percent,close_difference_percent_abs,body_size,body_size_abs,body_size_percentage,range,range_percentage,color");
  for candle in &candles {
    let session_type = determine_session_type(candle.start_timestamp);
    let (regular_session_start, _regular_session_end) = get_regular_market_session_start_and_end(candle.start_timestamp);
    let grouping_key = regular_session_start.timestamp();
    let close_difference_percent = if previous_close == 0.0 { 0.0 } else { (candle.close - previous_close) / previous_close };
    let close_difference_percent_abs = close_difference_percent.abs();
    let start_timestamp = candle.start_timestamp;
    let open = candle.open;
    let high = candle.high;
    let low = candle.low;
    let close = candle.close;
    let volume = candle.volume;
    let body_size = candle.close - candle.open;
    let body_size_abs = body_size.abs();
    let body_size_percentage = (candle.close - candle.open) / candle.open;
    let range = candle.high - candle.low;
    let range_percentage = (candle.high - candle.low) / candle.low;
    let color = if candle.close >= candle.open { Color::Green } else { Color::Red };
    if session_type == MarketSessionType::Regular {
      println!("{grouping_key},{start_timestamp},{session_type:?},{open},{high},{low},{close},{volume},{previous_close:.6},{close_difference_percent:.6},{close_difference_percent_abs:.6},{body_size:.6},{body_size_abs:.6},{body_size_percentage:.6},{range:.6},{range_percentage:.6},{color:?}");
    }
    previous_close = candle.close;
  }
}

I'm pretty sure I want to feed it `close_difference_percent` and then have it produce an estimate. What are your thoughts?

brandonros commented 1 year ago
// Add necessary imports for working with OHLC data
use std::iter::Iterator;

/// A single OHLC price bar.
#[derive(Debug)]
struct OhlcCandle {
    /// Opening price (not used by the HLC3 calculation).
    open: f64,
    /// Highest price within the bar.
    high: f64,
    /// Lowest price within the bar.
    low: f64,
    /// Closing price.
    close: f64,
}

/// Computes the HLC3 "typical price", `(high + low + close) / 3`, for each candle.
///
/// The output has one value per input candle, in the same order; an empty slice yields an
/// empty vector.
fn hlc3(candles: &[OhlcCandle]) -> Vec<f64> {
    let mut typical_prices = Vec::with_capacity(candles.len());
    for candle in candles {
        // Summed as (high + low) + close, matching left-to-right float association.
        let total = candle.high + candle.low + candle.close;
        typical_prices.push(total / 3.0);
    }
    typical_prices
}

/// Example: fits an ARIMA(p, d, q) model to the HLC3 series of a set of OHLC candles and
/// prints the estimated coefficients.
///
/// `estimate::fit` is assumed to come from the `arima` crate; its import is not shown in this
/// snippet. No RNG or simulated data is needed when fitting real candles.
///
/// On a failed fit, prints the model order, observation count and error to stderr, then exits
/// with status 1.
fn main() {
    // Provide your OHLC candle data. The single candle is only a placeholder: fitting a model
    // needs many more observations than it has coefficients.
    let ohlc_candles = vec![
        OhlcCandle { open: 100.0, high: 105.0, low: 99.0, close: 104.0 },
        // ... more OHLC candles ...
    ];

    // Collapse each candle to one number (its HLC3 typical price) to get a univariate series.
    let hlc3_values = hlc3(&ohlc_candles);

    // Model orders:
    // - ar_order (p): number of autoregressive coefficients (e.g. 2)
    // - diff_order (d): how many times the series is differenced before fitting (e.g. 0)
    // - ma_order (q): number of moving-average coefficients (e.g. 1)
    // NOTE(review): raw price levels are usually non-stationary; consider d = 1, or fit a
    // returns series (e.g. percent close-to-close changes) with d = 0 instead.
    let ar_order = 2;
    let diff_order = 0;
    let ma_order = 1;

    match estimate::fit(&hlc3_values, ar_order, diff_order, ma_order) {
        Ok(coef) => println!("Estimated parameters: {:?}", coef),
        Err(err) => {
            eprintln!(
                "ARIMA({}, {}, {}) fit failed on {} observations: {:?}",
                ar_order,
                diff_order,
                ma_order,
                hlc3_values.len(),
                err
            );
            std::process::exit(1);
        }
    }
}

The code above is from ChatGPT. I think I misunderstood the role of the `sim` module: since I'm fitting real OHLC data, I don't need the RNG-generated series that `sim` produces.

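The next step is figuring out which AR order, differencing order (d), and MA order work best. I'm guessing that means a grid search scored by mean absolute error: fit each candidate combination on earlier data, then measure the error of its forecasts on later, held-out data.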

Sorry for thinking out loud here. I was hoping this might help someone else who lands on this page with only a very basic understanding of the math and concepts involved.