In [1]:
#require "core,torch_ext,owl,matplotlib,jupyter.notebook"
In [2]:
open Core
open Torch_ext
open Owl
module T = Tensor
module S = Scalar
module D = Distributions
Out[2]:
module T = Torch_ext.Tensor
Out[2]:
module S = Torch_ext.Scalar
Out[2]:
module D = Torch_ext.Distributions
In [3]:
(* [bernoulli p] draws one boolean sample that is [true] with probability [p]. *)
let bernoulli p = T.bool_value (T.bernoulli_float ~p [])

(* [rand ()] draws one uniform float from the unit interval. *)
let rand () = T.float_value (T.rand [])

(* [categorical probs] samples an index [i] with probability [probs.(i)]
   using one uniform draw and a cumulative scan.

   Fix over the previous version: when [probs] nominally sums to 1.0 but
   floating-point rounding leaves the cumulative sum slightly below the
   uniform draw [u], the old fold ran off the end and hit [assert false],
   crashing the sampler.  We now return the final index in that case.

   Raises [Invalid_argument] if [probs] is empty. *)
let categorical probs =
  if List.is_empty probs then invalid_arg "categorical: empty distribution" ;
  let u = rand () in
  (* Tail-recursive scan: stop at the first index whose cumulative
     probability exceeds [u]. *)
  let rec scan i acc = function
    | [] ->
      i - 1 (* rounding guard: u >= total mass; pick the last category *)
    | p :: rest ->
      let acc = Float.(acc + p) in
      if Float.(u < acc) then i else scan (i + 1) acc rest
  in
  scan 0 0.0 probs

(* [gamma a b] draws one float sample from the Gamma distribution
   parameterised by [a] and [b] (as interpreted by [D.gamma] —
   presumably shape/rate; confirm against Torch_ext.Distributions). *)
let gamma a b = (D.gamma (T.f a) (T.f b))#sample () |> T.float_value
Out[3]:
val bernoulli : Base.float -> Base.bool = <fun>
Out[3]:
val rand : unit -> Base.float = <fun>
Out[3]:
val categorical : Core.Float.t list -> int = <fun>
Out[3]:
val gamma : Base.float -> Base.float -> Base.float = <fun>
In [4]:
(* Symbolic covariance-kernel expression for a Gaussian process.  Leaves
   carry their float hyper-parameters; [Plus]/[Times] combine two kernels
   by element-wise sum/product of their covariance matrices (see
   [eval_cov_mat]). *)
type kern =
  | Constant of float              (* constant covariance value *)
  | Linear of float                (* input offset for the linear kernel *)
  | Squared_exponential of float   (* length-scale parameter *)
  | Periodic of float * float      (* (scale, period) *)
  | Plus of kern * kern
  | Times of kern * kern
Out[4]:
type kern =
    Constant of float
  | Linear of float
  | Squared_exponential of float
  | Periodic of float * float
  | Plus of kern * kern
  | Times of kern * kern
In [5]:
(* Total node count of the kernel expression tree (operators and leaves). *)
let rec size (k : kern) : int =
  match k with
  | Plus (l, r) | Times (l, r) ->
    1 + size l + size r
  | Constant _ | Linear _ | Squared_exponential _ | Periodic _ ->
    1
Out[5]:
val size : kern -> int = <fun>
In [6]:
(* [eval_cov_mat k xs] builds the n-by-n covariance (Gram) matrix of
   kernel [k] over the 1-D input tensor [xs], where n = length of [xs].
   Composite kernels combine element-wise: [Plus] adds and [Times]
   multiplies the child matrices. *)
let rec eval_cov_mat (k : kern) (xs : T.t) : T.t =
  match k with
  | Constant param ->
    (* Constant kernel: every entry equals [param]. *)
    let n = T.shape1_exn xs in
    T.new_ones ~scale:(S.f param) [n; n]
  | Linear param ->
    (* Linear kernel with offset: entry (i,j) = (x_i - c)(x_j - c),
       built as an outer product of the shifted inputs. *)
    let xs_minus_param = T.sub_scalar xs (S.f param) in
    T.outer xs_minus_param ~vec2:xs_minus_param
  | Squared_exponential length_scal ->
    (* Pairwise differences x_i - x_j via outer products with ones. *)
    let self_diff =
      let tmp = T.(outer xs ~vec2:(ones_like xs)) in
      T.(tmp - tr tmp)
    in
    (* exp(-0.5 * d^2 / length_scal).  NOTE(review): the textbook SE
       kernel divides by the *squared* length scale (2 l^2); here the raw
       parameter is used as the divisor — confirm this is intended. *)
    T.(exp (T.div_scalar (T.scale_f (self_diff * self_diff) (-0.5)) (S.f length_scal)))
  | Periodic (scal, period) ->
    (* Periodic kernel: exp(-sin^2(freq * |x_i - x_j|) / scal),
       with freq = 2*pi / period. *)
    let freq = Float.(2.0 * pi / period) in
    let abs_self_diff =
      let tmp = T.outer xs ~vec2:(T.ones_like xs) in
      T.(abs (tmp - tr tmp))
    in
    T.(exp (scale_f (square (sin (scale_f abs_self_diff freq))) Float.(-1.0 / scal)))
  | Plus (left, right) ->
    T.(eval_cov_mat left xs + eval_cov_mat right xs)
  | Times (left, right) ->
    T.(eval_cov_mat left xs * eval_cov_mat right xs)
Out[6]:
val eval_cov_mat : kern -> T.t -> T.t = <fun>
In [7]:
(* Covariance matrix of [k] over [xs] with [noise] added on the diagonal
   (jitter that keeps the matrix positive definite for the Cholesky). *)
let compute_cov_matrix_vectorized (k : kern) (noise : float) (xs : T.t) : T.t =
  let dim = T.shape1_exn xs in
  let noise_diag = T.new_eye ~scale:(S.f noise) dim in
  T.(eval_cov_mat k xs + noise_diag)
Out[7]:
val compute_cov_matrix_vectorized : kern -> float -> T.t -> T.t = <fun>
In [8]:
(* Log likelihood of [ys] under a zero-mean multivariate normal whose
   covariance is the (noise-jittered) Gram matrix of [k] over [xs]. *)
let compute_log_likelihood (k : kern) (noise : float) (xs : T.t) (ys : T.t) : float =
  let mean = T.zeros_like xs in
  let chol = T.linalg_cholesky (compute_cov_matrix_vectorized k noise xs) in
  (D.multivariate_normal mean chol)#log_prob ys |> T.float_value
Out[8]:
val compute_log_likelihood : kern -> float -> T.t -> T.t -> float = <fun>
In [9]:
(* Sample a kernel expression from the prior: leaves with probability
   0.2 each, composite nodes ([Plus]/[Times]) with probability 0.1 each,
   recursing on composite children.  Leaf hyper-parameters are uniform
   draws. *)
let rec covariance_prior () : kern =
  match categorical [0.2; 0.2; 0.2; 0.2; 0.1; 0.1] with
  | 0 -> Constant (rand ())
  | 1 -> Linear (rand ())
  | 2 -> Squared_exponential (rand ())
  | 3 -> Periodic (rand (), rand ())
  | 4 -> Plus (covariance_prior (), covariance_prior ())
  | 5 -> Times (covariance_prior (), covariance_prior ())
  | _ -> assert false
Out[9]:
val covariance_prior : unit -> kern = <fun>
In [10]:
(* Pick a node of [k] uniformly at random.  [cur] is the heap-style index
   of [k] itself (root = 1; children of node i are 2i and 2i+1).  Returns
   the chosen node's index together with the subtree rooted there.  At a
   composite node we stop with probability 1/size and otherwise descend
   into a child with probability proportional to its subtree size, which
   makes the overall choice uniform over all nodes. *)
let rec pick_random_node_unbiased (k : kern) (cur : int) : int * kern =
  match k with
  | Constant _ | Linear _ | Squared_exponential _ | Periodic _ ->
    (cur, k)
  | Plus (l, r) | Times (l, r) -> (
    let total = size k |> Float.of_int in
    let stop_here = Float.(1.0 / total) in
    let go_left = Float.(of_int (size l) / total) in
    let go_right = Float.(of_int (size r) / total) in
    match categorical [stop_here; go_left; go_right] with
    | 0 -> (cur, k)
    | 1 -> pick_random_node_unbiased l (2 * cur)
    | 2 -> pick_random_node_unbiased r ((2 * cur) + 1)
    | _ -> assert false )
Out[10]:
val pick_random_node_unbiased : kern -> int -> int * kern = <fun>
In [11]:
(* Return [k] with the subtree at heap index [on_] replaced by [to_].
   [cur] is the heap index of [k] itself (children of i are 2i, 2i+1).
   If [on_] names no node of [k], the tree is returned unchanged. *)
let rec replace_subtree (k : kern) (cur : int) ~(to_ : kern) ~(on_ : int) : kern =
  if cur = on_ then to_
  else
    match k with
    | Constant _ | Linear _ | Squared_exponential _ | Periodic _ ->
      k
    | Plus (l, r) ->
      Plus (replace_subtree l (2 * cur) ~to_ ~on_, replace_subtree r ((2 * cur) + 1) ~to_ ~on_)
    | Times (l, r) ->
      Times (replace_subtree l (2 * cur) ~to_ ~on_, replace_subtree r ((2 * cur) + 1) ~to_ ~on_)
Out[11]:
val replace_subtree : kern -> int -> to_:kern -> on_:int -> kern = <fun>
In [12]:
(* Log proposal-correction term for the subtree MH move:
   log(size prev) - log(size prop), accounting for the size-dependent
   probability of selecting each node under [pick_random_node_unbiased]. *)
let get_alpha_subtree_unbiased (prev : kern) (prop : kern) : float =
  let log_size t = Float.log (Float.of_int (size t)) in
  Float.(log_size prev - log_size prop)
Out[12]:
val get_alpha_subtree_unbiased : kern -> kern -> float = <fun>
In [13]:
(* One MCMC chain state: the current kernel expression, observation-noise
   level, the training data it was scored on, and the cached log
   likelihood (avoids recomputing it in the accept/reject step). *)
type trace = {cov_k: kern; noise: float; xs: T.t; ys: T.t; log_likelihood: float}
Out[13]:
type trace = {
  cov_k : kern;
  noise : float;
  xs : T.t;
  ys : T.t;
  log_likelihood : float;
}
In [14]:
(* One Metropolis-Hastings step that replaces a uniformly chosen subtree
   of the kernel with a fresh draw from the prior, accepting with the
   usual log-likelihood ratio plus the size-based proposal correction. *)
let mh_resample_subtree_unbiased (prev_trace : trace) : trace =
  let on_, old_subtree = pick_random_node_unbiased prev_trace.cov_k 1 in
  let proposal = covariance_prior () in
  let cov_k = replace_subtree prev_trace.cov_k 1 ~to_:proposal ~on_ in
  let log_likelihood = compute_log_likelihood cov_k prev_trace.noise prev_trace.xs prev_trace.ys in
  let proposed = {prev_trace with cov_k; log_likelihood} in
  let log_alpha =
    Float.(proposed.log_likelihood - prev_trace.log_likelihood + get_alpha_subtree_unbiased old_subtree proposal)
  in
  if Float.(log (rand ()) < log_alpha) then proposed else prev_trace
Out[14]:
val mh_resample_subtree_unbiased : trace -> trace = <fun>
In [15]:
(* One MH step that proposes a new noise level from Gamma(1,1) + 0.01
   (the 0.01 floor keeps the covariance matrix well conditioned) and
   accepts on the log-likelihood ratio. *)
let mh_resample_noise (prev_trace : trace) : trace =
  let noise = Float.(gamma 1.0 1.0 + 0.01) in
  let log_likelihood = compute_log_likelihood prev_trace.cov_k noise prev_trace.xs prev_trace.ys in
  let proposed = {prev_trace with noise; log_likelihood} in
  let log_alpha = Float.(proposed.log_likelihood - prev_trace.log_likelihood) in
  if Float.(log (rand ()) < log_alpha) then proposed else prev_trace
Out[15]:
val mh_resample_noise : trace -> trace = <fun>
In [16]:
(* Draw an initial chain state: kernel from the prior, then noise from
   Gamma(1,1) + 0.01.  NOTE: the draw order here determines the RNG
   stream, so it must match any seeded-reproducibility expectations. *)
let initialize_trace (xs : T.t) (ys : T.t) : trace =
  let cov_k = covariance_prior () in
  let noise = Float.(gamma 1.0 1.0 + 0.01) in
  let log_likelihood = compute_log_likelihood cov_k noise xs ys in
  {cov_k; noise; xs; ys; log_likelihood}
Out[16]:
val initialize_trace : T.t -> T.t -> trace = <fun>
In [17]:
(* Run [iters] MCMC sweeps, each a subtree move followed by a noise move.
   Returns the input unchanged when [iters] <= 0. *)
let run_mcmc (prev_trace : trace) (iters : int) : trace =
  let rec sweep t remaining =
    if remaining <= 0 then t
    else sweep (mh_resample_noise (mh_resample_subtree_unbiased t)) (remaining - 1)
  in
  sweep prev_trace iters
Out[17]:
val run_mcmc : trace -> int -> trace = <fun>
In [18]:
(* Affine-map [xs] so that its minimum lands on [yl] and its maximum on
   [yh].  Assumes [xs] is non-constant (min < max) — TODO confirm callers
   never pass a constant tensor, as that would divide by zero. *)
let rescale_linear (xs : T.t) (yl : float) (yh : float) : T.t =
  let lo = T.minimum xs |> T.float_value in
  let hi = T.maximum xs |> T.float_value in
  let slope = Float.((yh - yl) / (hi - lo)) in
  let intercept = Float.(yh - (hi * slope)) in
  T.(add_scalar (scale_f xs slope) (S.f intercept))
Out[18]:
val rescale_linear : T.t -> float -> float -> T.t = <fun>
In [19]:
(* Read a two-column CSV with headers "x" and "y", rescale x to [0,1] and
   y to [-1,1], and split off the last [n_test] points of each as the
   held-out test set.  Returns ((xs_train, ys_train), (xs_test, ys_test)). *)
let load_dataset_from_path (path : string) (n_test : int) : (T.t * T.t) * (T.t * T.t) =
  let df = Dataframe.of_csv ~sep:',' ~types:[|"f"; "f"|] path in
  let column name = Dataframe.get_col_by_name df name |> Dataframe.unpack_float_series |> T.of_float1 in
  let xs = rescale_linear (column "x") 0.0 1.0 in
  let ys = rescale_linear (column "y") (-1.0) 1.0 in
  let n = T.shape1_exn xs in
  (* Split a length-n tensor into its first n - n_test and last n_test
     entries; the warning suppression covers the statically irrefutable
     two-element pattern. *)
  let split t =
    let[@warning "-8"] [train; test] = T.split_with_sizes ~split_sizes:[n - n_test; n_test] t in
    (train, test)
  in
  let xs_train, xs_test = split xs in
  let ys_train, ys_test = split ys in
  ((xs_train, ys_train), (xs_test, ys_test))
Out[19]:
val load_dataset_from_path : string -> int -> (T.t * T.t) * (T.t * T.t) =
  <fun>
In [20]:
(* Load the airline-passengers dataset, holding out the last 15 points. *)
let (xs_train, ys_train), (xs_test, ys_test) = load_dataset_from_path "airline.csv" 15
2022-11-30 09:32:13.887 WARN : Owl_io.head: ignored exception "Assert_failure src/base/misc/owl_io.ml:138:9"
Out[20]:
val xs_train : T.t = <abstr>
val ys_train : T.t = <abstr>
val xs_test : T.t = <abstr>
val ys_test : T.t = <abstr>
In [21]:
(* Posterior predictive mean and covariance of a zero-mean GP at
   [new_xs], conditioned on observations ([xs], [ys]).  Builds the joint
   covariance over the concatenated inputs, partitions it into the four
   blocks, and applies the standard Gaussian conditioning formulas.  The
   result is symmetrised to cancel floating-point asymmetry ahead of a
   Cholesky factorisation. *)
let get_conditional_mu_cov (cov_k : kern) (noise : float) (xs : T.t) (ys : T.t) (new_xs : T.t) : T.t * T.t =
  let n_obs = T.shape1_exn xs in
  let n_new = T.shape1_exn new_xs in
  let joint = compute_cov_matrix_vectorized cov_k noise (T.cat [xs; new_xs]) |> T.to_float2_exn in
  (* Extract a [rows]-by-[cols] block of the joint matrix starting at
     (row0, col0) as a fresh tensor. *)
  let block rows cols row0 col0 =
    Array.init rows ~f:(fun i -> Array.init cols ~f:(fun j -> joint.(row0 + i).(col0 + j))) |> T.of_float2
  in
  let cov11 = block n_obs n_obs 0 0 in
  let cov22 = block n_new n_new n_obs n_obs in
  let cov12 = block n_obs n_new 0 n_obs in
  let cov21 = block n_new n_obs n_obs 0 in
  let mu_obs = T.zeros [n_obs] in
  let mu_new = T.zeros [n_new] in
  let conditional_mu = T.(mu_new + matmul cov21 (T.linalg_solve cov11 (ys - mu_obs))) in
  let conditional_cov = T.(cov22 - matmul cov21 (T.linalg_solve cov11 cov12)) in
  let conditional_cov = T.(scale_f conditional_cov 0.5 + scale_f (tr conditional_cov) 0.5) in
  (conditional_mu, conditional_cov)
Out[21]:
val get_conditional_mu_cov : kern -> float -> T.t -> T.t -> T.t -> T.t * T.t =
  <fun>
In [22]:
(* Log density of [new_ys] under the GP posterior predictive distribution
   at [new_xs], conditioned on the observations ([xs], [ys]). *)
let compute_log_likelihood_predictive (cov_k : kern) (noise : float) (xs : T.t) (ys : T.t) (new_xs : T.t) (new_ys : T.t) : float =
  let mu, cov = get_conditional_mu_cov cov_k noise xs ys new_xs in
  let dist = D.multivariate_normal mu (T.linalg_cholesky cov) in
  dist#log_prob new_ys |> T.float_value
Out[22]:
val compute_log_likelihood_predictive :
  kern -> float -> T.t -> T.t -> T.t -> T.t -> float = <fun>
In [23]:
(* Draw [n] independent sample paths from the GP posterior predictive at
   [new_xs], conditioned on the observations ([xs], [ys]). *)
let gp_predictive_samples (cov_k : kern) (noise : float) (xs : T.t) (ys : T.t) (new_xs : T.t) (n : int) : T.t list =
  let mu, cov = get_conditional_mu_cov cov_k noise xs ys new_xs in
  let dist = D.multivariate_normal mu (T.linalg_cholesky cov) in
  List.init n ~f:(fun _ -> dist#sample ())
Out[23]:
val gp_predictive_samples :
  kern -> float -> T.t -> T.t -> T.t -> int -> T.t list = <fun>
In [24]:
(* Per-epoch summary: the chain's log likelihood at the end of the epoch
   and the posterior-predictive sample paths at the held-out inputs. *)
type result =
  { log_likelihood: float
  ; predictions_held_out: T.t list }
Out[24]:
type result = { log_likelihood : float; predictions_held_out : T.t list; }
In [25]:
(* Advance the chain by [iters] MCMC sweeps, then draw [npred_out]
   posterior-predictive samples at the held-out inputs.  [ys_test] is
   accepted for signature symmetry with the training data but unused. *)
let infer_and_predict (trace : trace) (iters : int) (xs_train : T.t) (ys_train : T.t) (xs_test : T.t) (ys_test : T.t) (npred_out : int) : trace * result =
  ignore (ys_test : T.t) ;
  let trace = run_mcmc trace iters in
  let predictions_held_out =
    gp_predictive_samples trace.cov_k trace.noise xs_train ys_train xs_test npred_out
  in
  (trace, {log_likelihood= trace.log_likelihood; predictions_held_out})
Out[25]:
val infer_and_predict :
  trace -> int -> T.t -> T.t -> T.t -> T.t -> int -> trace * result = <fun>
In [26]:
(* Seed the RNG, draw an initial trace, then run [epochs] rounds of
   [infer_and_predict] with [iters] MCMC sweeps each, threading the trace
   through and collecting the per-epoch results in chronological order. *)
let run_pipeline (xs_train : T.t) (ys_train : T.t) (xs_test : T.t) (ys_test : T.t) (iters : int) (epochs : int) (npred_held_out : int) (seed : int) : result list =
  Torch_core.Wrapper.manual_seed seed ;
  let initial = initialize_trace xs_train ys_train in
  let _final_trace, rev_results =
    List.fold
      (List.init epochs ~f:(fun _ -> ()))
      ~init:(initial, [])
      ~f:(fun (trace, acc) () ->
        let trace, result = infer_and_predict trace iters xs_train ys_train xs_test ys_test npred_held_out in
        (trace, result :: acc) )
  in
  List.rev rev_results
Out[26]:
val run_pipeline :
  T.t -> T.t -> T.t -> T.t -> int -> int -> int -> int -> result list = <fun>
In [27]:
(* 200 epochs of 5 MCMC sweeps each, 100 predictive samples per epoch,
   seed 42. *)
let stats = run_pipeline xs_train ys_train xs_test ys_test 5 200 100 42
Out[27]:
val stats : result list =
  [{log_likelihood = -0.32733154296875;
    predictions_held_out =
     [<abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
      <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
      <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
      <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
      <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
      <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
      <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
      <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
      <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
      <abstr>; <abstr>; ...]};
   ...]
In [28]:
open Matplotlib
In [29]:
(* Render the current matplotlib figure to PNG, clear the figure state,
   and display the image inline in the notebook. *)
let plot () =
  let data = Mpl.plot_data `png in
  Mpl.clf ();
  ignore (Jupyter_notebook.display ~base64:true "image/png" data)

(* One-time matplotlib configuration: headless Agg backend, ggplot style. *)
let () =
  Mpl.set_backend Agg;
  Mpl.style_use "ggplot"
Out[29]:
val plot : unit -> unit = <fun>
In [30]:
(* Pull the final epoch's statistics for plotting. *)
let likelihood, predictions = (List.last_exn stats).log_likelihood, (List.last_exn stats).predictions_held_out;;
Out[30]:
val likelihood : float = 105.283782958984375
val predictions : T.t list =
  [<abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
   <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
   <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
   <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
   <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
   <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
   <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
   <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
   <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>;
   <abstr>; <abstr>; ...]
In [31]:
(* Original (pre-rescaling) axis ranges of the airline dataset: years on
   the x-axis, passenger counts on the y-axis.  Used to undo the
   [rescale_linear] normalisation for plotting. *)
let (x_min, y_min) = (1.949041666666666742e+03, 1.12e+02)

let (x_max, y_max) = (1.960958333333333258e+03, 4.32e+02)
Out[31]:
val x_min : float = 1949.04166666666674
val y_min : float = 112.
Out[31]:
val x_max : float = 1960.95833333333326
val y_max : float = 432.
In [32]:
(* Invert [rescale_linear]: map each value of [xs] from the normalised
   range [yl, yh] back to the original range [xl, xh]. *)
let unscale_linear xs yl yh xl xh =
  let scale = Float.((yh - yl) / (xh - xl)) in
  let shift = Float.(yh - (xh * scale)) in
  let invert v = Float.((v - shift) / scale) in
  Array.map xs ~f:invert
Out[32]:
val unscale_linear :
  float Core.Array.t ->
  float -> float -> float -> float -> float Core.Array.t = <fun>
In [33]:
(* Map the rescaled training/test tensors back to their original units
   (years on x, passenger counts on y) for plotting. *)
let xs_train_actual = unscale_linear (T.to_float1_exn xs_train) 0.0 1.0 x_min x_max

let ys_train_actual = unscale_linear (T.to_float1_exn ys_train) (-1.0) 1.0 y_min y_max

let xs_test_actual = unscale_linear (T.to_float1_exn xs_test) 0.0 1.0 x_min x_max

let ys_test_actual = unscale_linear (T.to_float1_exn ys_test) (-1.0) 1.0 y_min y_max

(* Shadow [predictions] with its unscaled (original-units) form. *)
let predictions = List.map predictions ~f:(fun p -> unscale_linear (T.to_float1_exn p) (-1.0) 1.0 y_min y_max)
Out[33]:
val xs_train_actual : Base.float Core.Array.t =
  [|1949.04166666666674; 1949.1249999990689; 1949.20833333147084;
    1949.29166666387277; 1949.37499999627494; 1949.45833332867687;
    1949.54166666107881; 1949.62499999348097; 1949.7083333258829;
    1949.79166670267796; 1949.874999990687; 1949.95833336748206;
    1950.04166665549087; 1950.12500003228593; 1950.20833332029497;
    1950.29166660830401; 1950.37499998509907; 1950.45833327310788;
    1950.54166673868895; 1950.62499993791198; 1950.70833331470703;
    1950.79166669150209; 1950.87500006829714; 1950.95833326752;
    1951.041666644315; 1951.12499984353781; 1951.20833339790511;
    1951.29166659712791; 1951.37499997392297; 1951.45833335071802;
    1951.54166654994106; 1951.62499992673611; 1951.70833330353116;
    1951.79166668032622; 1951.87499987954902; 1951.95833343391632;
    1952.04166681071138; 1952.12500000993418; 1952.20833320915699;
    1952.29166676352429; 1952.3749999627471; 1952.4583335171144;
    1952.5416667163372; 1952.62499991556024; 1952.70833346992731;
    1952.79166666915035; 1952.87499986837315; 1952.95833342274045;
    1953.04166662196326; 1953.12500017633056; 1953.2083330204091;
    1953.29166657477617; 1953.37500012914347; 1953.45833297322201;
    1953.54166652758931; 1953.62500008195639; 1953.70833328117942;
    1953.79166648040223; 1953.87500003476953; 1953.95833323399233;
    1954.04166643321514; 1954.12499998758244; 1954.20833318680525;
    1954.29166674117255; 1954.37499994039536; 1954.45833313961839;
    1954.54166669398546; 1954.62500024835276; 1954.7083330924313;
    1954.7916666467986; 1954.87500020116568; 1954.95833304524422;
    1955.04166695475578; 1955.12500015397882; 1955.20833335320162;
    1955.29166655242443; 1955.37499975164747; 1955.45833295087027;
    1955.54166686038184; 1955.62500005960464; 1955.70833325882768;
    1955.79166645805049; 1955.87500036756205; 1955.95833285649633;
    1956.04166676600789; 1956.1249999652307; 1956.20833316445351;
    1956.29166636367654; 1956.37500027318811; 1956.45833276212215;
    1956.54166667163372; 1956.62499987085653; 1956.70833307007956;
    1956.79166626930237; 1956.87500017881393; 1956.95833337803674;
    1957.04166657725978; 1957.12499977648258; 1957.20833368599415;
    1957.29166688521696; 1957.37499937415123; 1957.4583332836628;
    1957.5416664828856; 1957.62499968210864; 1957.7083335916202;
    1957.79166679084301; 1957.87499927977728; 1957.95833389957738;
    1958.04166638851166; 1958.12499958773446; 1958.20833349724603;
    1958.29166669646906; 1958.37499989569187; 1958.45833380520344;
    1958.54166629413771; 1958.62499949336052; 1958.70833340287209;
    1958.79166660209489; 1958.87499980131793; 1958.95833371082949;
    1959.04166619976354; 1959.12499939898657; 1959.20833330849814;
    1959.29166650772095; 1959.37499970694375; 1959.45833361645532;
    1959.54166681567835; 1959.6249993046124; 1959.70833321412397|]
Out[33]:
val ys_train_actual : Base.float Core.Array.t =
  [|116.942083358764677; 120.648653030395536; 129.29729652404788;
    127.444021224975614; 122.501928329467802; 131.15058135986331;
    139.181472778320341; 139.181472778320341; 131.768342971801786;
    121.266414642334013; 112.000000000000028; 120.648653030395536;
    118.795368194580107; 125.590736389160185; 134.857141494751;
    131.15058135986331; 124.972974777221708; 139.799224853515653;
    152.772209167480497; 152.772209167480497; 145.359069824218778;
    129.915058135986357; 118.17760658264163; 134.239379882812528;
    137.328187942504911; 140.41698646545413; 157.714282989501982;
    148.447877883911161; 154.007722854614286; 157.714282989501982;
    170.687257766723661; 170.687257766723661; 161.420852661132841;
    147.830116271972685; 137.945949554443388; 150.301162719726591;
    153.38996124267581; 158.949806213378935; 166.980697631835966;
    159.567567825317411; 160.803091049194364; 182.424709320068388;
    189.837839126586942; 197.250968933105497; 176.864864349365263;
    165.745174407959; 154.007722854614286; 167.598459243774442;
    168.833972930908232; 168.833972930908232; 193.54440402984622;
    192.926647186279325; 189.220077514648466; 197.868730545043974;
    210.841700553894071; 215.783788681030302; 194.162165641784696;
    178.100387573242216; 158.949806213378935; 171.922780990600614;
    173.776065826416044; 163.891899108886747; 192.926647186279325;
    187.984554290771513; 192.308885574340849; 210.841700553894071;
    234.316603660583525; 228.7567586898804; 207.752897262573271;
    189.220077514648466; 173.158304214477567; 189.220077514648466;
    197.250968933105497; 191.691123962402372; 212.694985389709501;
    213.930503845214872; 214.548265457153349; 242.347492694854765;
    272.617761611938477; 262.115828514099121; 240.494210243225126;
    217.019307136535673; 194.162165641784696; 219.490348815918;
    223.196913719177275; 218.872587203979521; 243.583013534545927;
    241.111971855163603; 244.200775146484403; 278.795368194580078;
    302.888032913208; 297.945944786071777; 267.057916641235352;
    236.787645339965849; 215.166027069091825; 236.787645339965849;
    242.347492694854765; 233.698844432830839; 267.675673484802246;
    262.733590126037598; 267.057916641235352; 308.447877883911133;
    335.011579513549805; 336.247102737426758; 297.328183174133301;
    262.115828514099121; 236.169886112213163; 255.320465087890653;
    257.791506767272949; 244.200775146484403; 271.382238388061523;
    262.733590126037598; 272.; 316.478759765625; 351.073352813720703;
    359.722005844116211; 297.328183174133301; 269.528958320617676;
    239.258689403533964; 255.93822669982913; 270.14671516418457;
    259.027029991149902; 298.563706398010254; 292.386099815368652;
    307.21235466003418; 339.335901260376; 386.285707473754883;
    393.081075668334961; 333.776056289672852|]
Out[33]:
val xs_test_actual : Base.float Core.Array.t =
  [|1959.791666413347; 1959.87499961256981; 1959.95833352208138;
    1960.04166672130418; 1960.12499921023846; 1960.20833383003878;
    1960.29166631897283; 1960.37499951819586; 1960.45833342770743;
    1960.54166662693024; 1960.62499982615304; 1960.70833373566461;
    1960.79166622459888; 1960.87499942382169; 1960.95833333333326|]
Out[33]:
val ys_test_actual : Base.float Core.Array.t =
  [|299.18146800994873; 271.382238388061523; 297.945944786071777;
    305.35906982421875; 289.297296524047852; 306.594593048095703;
    332.540542602539062; 339.335901260376; 378.254825592041;
    431.999990463256836; 422.115823745727539; 361.575290679931641;
    332.540542602539062; 288.679534912109375; 314.625484466552734|]
Out[33]:
val predictions : Base.float Core.Array.t list =
  [[|330.707685470581055; 296.305815696716309; 289.742047309875488;
     269.697917222976685; 288.021586656570435; 305.400754928588867;
     333.346654891967773; 365.653602600097656; 352.545320510864258;
     374.761440277099609; 365.450279235839844; 325.374843597412109;
     284.445456981658936; 232.76416492462161; 254.147569656372099|];
   [|305.56135368347168; 310.397994041442871; 271.280216693878174;
     271.416354894638062; 243.90143585205081; 292.093305110931396;
     331.607682228088379; 340.525538444519043; 392.047702789306641;
     413.60710334777832; 343.402454376220703; 320.700194358825684;
     288.140811443328857; 298.792290210723877; 315.322978019714355|];
   [|311.504961967468262; 256.744551181793213; 258.004705429077148;
     272.146616697311401; 253.562920808792143; 308.243376731872559;
     339.759184837341309; 382.993051528930664; 425.648414611816406;
     404.158651351928711; 359.362737655639648; 298.415703296661377;
     287.539165735244751; 260.542180061340332; 251.64776039123538|];
   [|316.569625854492188; 306.975333213806152; 289.412371635437;
     296.205336570739746; 299.213249206542969; 264.367714643478394;
     330.784551620483398; 368.536989212036133; 372.405254364013672;
     449.065887451171875; 406.155559539794922; 349.764158248901367;
     286.690146446228; 268.889368295669556; 314.255387306213379|];
   [|291.383325576782227; 280.058149814605713; 276.112815260887146;
     258.113467216491699; 276.431219100952148; 268.97474479675293;
     267.851068496704102; 353.927709579467773; 356.940290451049805;
     403.720380783081055; 374.169103622436523; 341.738578796386719;
     239.242207527160673; 236.1877889633179; 264.331582307815552|];
   [|251.824357032775907; 246.976282119751; 265.330719470977783;
     215.617124557495146; 216.09635543823245; 237.677475929260282;
     255.283967971801786; 301.390163421630859; 355.059425354003906;
     357.357170104980469; 327.086846351623535; 286.805221557617188;
     252.471110343933134; 265.179692029953; 232.036540031433134|];
   [|300.057460784912109; 313.118640899658203; 295.407411575317383;
     236.287285804748564; 280.817483186721802; 271.616135358810425;
     327.780119895935059; 304.634372711181641; 353.432104110717773;
     397.650014877319336; 354.908363342285156; 344.106246948242188;
     341.013404846191406; 279.216455936431885; 280.242983818054199|];
   [|314.986373901367188; 273.83110237121582; 301.001011848449707;
     247.1067924499512; 236.972188472747831; 281.482946395874;
     330.980231285095215; 358.717042922973633; 382.003871917724609;
     391.6759033203125; 345.084416389465332; 311.646334648132324;
     284.724487781524658; 274.896474003791809; 269.188971042633057|];
   [|282.811762809753418; 299.479901313781738; 299.550399303436279;
     285.499090671539307; 271.299948811531067; 297.21726131439209;
     302.281913280487061; 317.203943252563477; 358.03973388671875;
     364.854824066162109; 350.330087661743164; 323.56013011932373;
     268.48432731628418; 282.003904104232788; 265.654129505157471|];
   [|289.388193607330322; 302.230047702789307; 299.07942008972168;
     277.325747132301331; 283.688425540924072; 334.436447143554688;
     322.680789947509766; 389.011528015136719; 417.827655792236328;
     379.5653076171875; 387.690689086914062; 319.211103439331055;
     301.773712158203125; 313.35810375213623; 314.166461944580078|];
   [|327.811681747436523; 310.869640827178955; 300.668398857116699;
     291.245830774307251; 271.239329099655151; 287.894865989685059;
     302.464890003204346; 385.771171569824219; 424.379446029663086;
     431.111652374267578; 399.149753570556641; 327.603251457214355;
     288.835534572601318; 293.560354232788086; 243.232724189758329|];
   [|311.109642505645752; 286.688019752502441; 320.729486465454102;
     239.373137474060087; 270.931279182434082; 302.865285396575928;
     342.611162185668945; 392.905208587646484; 383.18072509765625;
     360.137502670288086; 338.655392646789551; 266.257898330688477;
     277.587176084518433; 236.72835063934329; 235.435425758361845|];
   [|283.232359409332275; 295.798096179962158; 256.386773109436035;
     305.138818740844727; 303.815061569213867; 309.752096652984619;
     339.99015998840332; 379.773580551147461; 362.775251388549805;
     379.202529907226562; 341.031105041503906; 305.207225799560547;
     225.253965377807646; 258.494677066803; 260.132315874099731|];
   [|328.425619125366211; 288.945033073425293; 286.305486679077148;
     285.867502212524414; 307.104770660400391; 304.589497566223145;
     323.011648178100586; 356.541692733764648; 441.455013275146484;
     376.678211212158203; 368.185512542724609; 328.669301986694336;
     287.281509160995483; 302.689280033111572; 337.273323059082031|];
   [|305.708758354187; 305.353257179260254; 310.519530296325684;
     280.788464665412903; 311.73677396774292; 322.226292610168457;
     360.869590759277344; 367.05497932434082; 420.62858772277832;
     451.05792236328125; 400.608131408691406; 355.276233673095703;
     321.094934463501; 293.699397563934326; 304.913944721221924|];
   [|361.329433441162109; 328.553211212158203; 297.168945789337158;
     292.714173316955566; 266.588826894760132; 286.875962734222412;
     303.65482759475708; 337.194797515869141; 398.243124008178711;
     408.458187103271484; 407.662784576416; 385.289632797241211;
     304.29623556137085; 273.913884282112122; 307.929305553436279|];
   [|308.097619533538818; 277.473988056182861; 292.403227806091309;
     245.433790206909208; 244.465643882751493; 246.471123218536405;
     286.714024066925049; 357.822563171386719; ...|];
   ...]
In [34]:
(* Final figure: observed training data (dashed black), translucent green
   posterior-predictive sample paths, and held-out points (red '+'). *)
let fig, ax = Fig.create_with_ax () in
Ax.plot ax ~label:"Observed Data" ~color:Black ~linestyle:(Other "--") ~marker:'.' ~xs:xs_train_actual ys_train_actual;
(* Very low alpha per sample path so the ensemble reads as a density. *)
List.iter predictions ~f:(fun p -> Ax.plot ax ~color:Green ~linestyle:(Other "-") ~alpha:0.025 ~xs:xs_test_actual p);
(* Empty series: gives the translucent samples one opaque legend entry. *)
Ax.plot ax ~label:"Predictions" ~color:Green [||];
Ax.scatter ax ~label:"Held-out Data" ~marker:'+' ~s:30.0 ~c:Red (Array.zip_exn xs_test_actual ys_test_actual);
Ax.legend ax ~loc:UpperLeft ~framealpha:0.0 ();
Ax.set_xlabel ax ~fontsize:12.0 "Year";
Ax.set_ylabel ax ~fontsize:12.0 "Passenger Volume";
Ax.set_xlim ax ~left:1947.5 ();
Ax.set_xticks ax ~labels:[|"1948"; "1950"; "1952"; "1954"; "1956"; "1958"; "1960"|] [|1948.;1950.;1952.;1954.;1956.;1958.;1960.|];
Fig.set_tight_layout fig true;
Fig.set_size_inches fig ~w:4.0 ~h:3.0;
plot ()
Out[34]:
- : unit = ()
In [ ]: