#require "core,torch_ext,owl,matplotlib,jupyter.notebook"
open Core
open Torch_ext
open Owl
(* Short aliases for the Torch_ext modules used throughout the notebook.
   They are bound with explicit [Torch_ext.] paths on purpose: [open Owl]
   comes after [open Torch_ext], so unqualified [Tensor] / [Scalar] /
   [Distributions] could resolve to Owl modules of the same name (if any). *)
module T = Torch_ext.Tensor
module S = Torch_ext.Scalar
module D = Torch_ext.Distributions
(** [bernoulli p] draws one boolean via [T.bernoulli_float] with parameter [p]
    (presumably the success probability, as in torch's bernoulli — confirm). *)
let bernoulli p = T.bool_value (T.bernoulli_float ~p [])
(** [rand ()] returns a single draw from torch's uniform sampler [T.rand],
    converted to an OCaml float. *)
let rand () =
  let draw = T.rand [] in
  T.float_value draw
(** [categorical probs] samples an index [i] with probability equal to the
    [i]-th weight of [probs], by comparing one uniform draw against the
    running cumulative sum of the weights.

    The weights are expected to sum to (approximately) one. If floating-point
    rounding leaves the total slightly below the uniform draw, the last index
    is returned instead of failing.

    @raise Invalid_argument if [probs] is empty. *)
let categorical probs =
  if List.is_empty probs then invalid_arg "categorical: empty probability list" ;
  let u = rand () in
  (* The first index whose running total exceeds [u] is the sample. *)
  let rec pick i cumulative = function
    | [] ->
        (* Rounding left the total at or below [u]; here [i] equals the list
           length, so [i - 1] is the last category. *)
        i - 1
    | p :: rest ->
        let cumulative = Float.(cumulative + p) in
        if Float.(u < cumulative) then i else pick (i + 1) cumulative rest
  in
  pick 0 0.0 probs
(** [gamma a b] draws one sample from [D.gamma] built from scalar tensors
    [a] and [b]. These are presumably (concentration, rate) as in
    torch.distributions.Gamma — confirm against Torch_ext. *)
let gamma a b =
  let dist = D.gamma (T.f a) (T.f b) in
  T.float_value (dist#sample ())
val bernoulli : Base.float -> Base.bool = <fun>
val rand : unit -> Base.float = <fun>
val categorical : Core.Float.t list -> int = <fun>
val gamma : Base.float -> Base.float -> Base.float = <fun>
(** Compositional covariance-kernel expression. [eval_cov_mat] turns each
    node into an n x n matrix over inputs x_1..x_n as documented per
    constructor. *)
type kern =
| Constant of float (** every entry equals the parameter *)
| Linear of float (** (x_i - c) * (x_j - c) *)
| Squared_exponential of float (** exp (-0.5 * (x_i - x_j)^2 / l) *)
| Periodic of float * float (** (scale, period): exp (-(1/scale) * sin^2 (2 pi |x_i - x_j| / period)) *)
| Plus of kern * kern (** element-wise sum of the two sub-kernels' matrices *)
| Times of kern * kern (** element-wise product of the two sub-kernels' matrices *)
type kern = Constant of float | Linear of float | Squared_exponential of float | Periodic of float * float | Plus of kern * kern | Times of kern * kern
(** [size k] is the number of nodes in the kernel expression [k]. Leaves and
    binary operators each count as one node. *)
let rec size (k : kern) : int =
  match k with
  | Plus (l, r) | Times (l, r) -> 1 + size l + size r
  | Constant _ | Linear _ | Squared_exponential _ | Periodic _ -> 1
val size : kern -> int = <fun>
(** [eval_cov_mat k xs] builds the covariance matrix of kernel [k] over the
    1-D input tensor [xs]. Entry (i, j) is the kernel evaluated at
    (xs.(i), xs.(j)). Composite kernels combine their children's matrices
    element-wise. *)
let rec eval_cov_mat (k : kern) (xs : T.t) : T.t =
  (* Matrix whose (i, j) entry is xs.(i) - xs.(j). The outer product with a
     ones vector repeats xs along rows; subtracting its transpose gives the
     differences. *)
  let pairwise_diff () =
    let xi_grid = T.outer xs ~vec2:(T.ones_like xs) in
    T.(xi_grid - tr xi_grid)
  in
  match k with
  | Constant c ->
      let n = T.shape1_exn xs in
      T.new_ones ~scale:(S.f c) [n; n]
  | Linear c ->
      let shifted = T.sub_scalar xs (S.f c) in
      T.outer shifted ~vec2:shifted
  | Squared_exponential length_scale ->
      let pair_diff = pairwise_diff () in
      let exponent = T.div_scalar (T.scale_f T.(pair_diff * pair_diff) (-0.5)) (S.f length_scale) in
      T.exp exponent
  | Periodic (period_scale, period) ->
      let freq = 2.0 *. Float.pi /. period in
      let coeff = -1.0 /. period_scale in
      let abs_diff = T.abs (pairwise_diff ()) in
      let sin_sq = T.square (T.sin (T.scale_f abs_diff freq)) in
      T.exp (T.scale_f sin_sq coeff)
  | Plus (left, right) ->
      T.(eval_cov_mat left xs + eval_cov_mat right xs)
  | Times (left, right) ->
      T.(eval_cov_mat left xs * eval_cov_mat right xs)
val eval_cov_mat : kern -> T.t -> T.t = <fun>
(** [compute_cov_matrix_vectorized k noise xs] is the kernel matrix of [k]
    over [xs] plus [noise] added to every diagonal entry. *)
let compute_cov_matrix_vectorized (k : kern) (noise : float) (xs : T.t) : T.t =
  let kernel_part = eval_cov_mat k xs in
  let noise_part = T.new_eye ~scale:(S.f noise) (T.shape1_exn xs) in
  T.(kernel_part + noise_part)
val compute_cov_matrix_vectorized : kern -> float -> T.t -> T.t = <fun>
(** [compute_log_likelihood k noise xs ys] is the log-density of [ys] under a
    zero-mean multivariate normal. Its covariance is
    [compute_cov_matrix_vectorized k noise xs], passed to the distribution as
    its Cholesky factor. The Cholesky step presumably raises if that matrix is
    not positive definite — confirm against Torch_ext. *)
let compute_log_likelihood (k : kern) (noise : float) (xs : T.t) (ys : T.t) : float =
  let chol = compute_cov_matrix_vectorized k noise xs |> T.linalg_cholesky in
  let dist = D.multivariate_normal (T.zeros_like xs) chol in
  T.float_value (dist#log_prob ys)
val compute_log_likelihood : kern -> float -> T.t -> T.t -> float = <fun>
(** [covariance_prior ()] samples a random kernel expression.
    - Each leaf kind (Constant, Linear, Squared_exponential, Periodic) is chosen
      with probability 0.2; Plus and Times are chosen with 0.1 each.
    - Leaf parameters come from [rand ()].
    - A node has two children with total probability 0.2, so the expected
      number of children per node is 0.4 and trees are finite with
      probability 1. *)
let rec covariance_prior () : kern =
  (* Constructor arguments are written exactly as before: their evaluation
     order determines how the RNG stream is consumed, and hence seeded
     results. *)
  match categorical [0.2; 0.2; 0.2; 0.2; 0.1; 0.1] with
  | 0 -> Constant (rand ())
  | 1 -> Linear (rand ())
  | 2 -> Squared_exponential (rand ())
  | 3 -> Periodic (rand (), rand ())
  | 4 -> Plus (covariance_prior (), covariance_prior ())
  | 5 -> Times (covariance_prior (), covariance_prior ())
  | _ -> assert false
val covariance_prior : unit -> kern = <fun>
(** [pick_random_node_unbiased k cur] selects a node of [k] and returns it
    with its position index.
    - Positions use heap numbering: the node at [cur] has children at
      [2 * cur] and [2 * cur + 1]. Callers pass [1] for the root.
    - At a binary node of size [s], the walk stops with probability [1/s]
      and otherwise descends into a child with probability proportional to
      that child's size.
    - As a result, every node of the tree is chosen with probability
      1 / [size k]. *)
let rec pick_random_node_unbiased (k : kern) (cur : int) : int * kern =
  match k with
  | Constant _ | Linear _ | Squared_exponential _ | Periodic _ -> (cur, k)
  | Plus (left, right) | Times (left, right) -> (
      let total = Float.of_int (size k) in
      let weight sub = Float.of_int (size sub) /. total in
      match categorical [1.0 /. total; weight left; weight right] with
      | 0 -> (cur, k)
      | 1 -> pick_random_node_unbiased left (2 * cur)
      | 2 -> pick_random_node_unbiased right ((2 * cur) + 1)
      | _ -> assert false )
val pick_random_node_unbiased : kern -> int -> int * kern = <fun>
(** [replace_subtree k cur ~to_ ~on_] returns [k] with the node at heap
    position [on_] replaced by [to_]. [cur] is the position of [k] itself
    (callers pass [1] for the root); the children of position [p] are [2p]
    and [2p + 1]. *)
let rec replace_subtree (k : kern) (cur : int) ~(to_ : kern) ~(on_ : int) : kern =
  if cur = on_ then to_
  else
    let recurse sub pos = replace_subtree sub pos ~to_ ~on_ in
    match k with
    | Constant _ | Linear _ | Squared_exponential _ | Periodic _ -> k
    | Plus (left, right) -> Plus (recurse left (2 * cur), recurse right ((2 * cur) + 1))
    | Times (left, right) -> Times (recurse left (2 * cur), recurse right ((2 * cur) + 1))
val replace_subtree : kern -> int -> to_:kern -> on_:int -> kern = <fun>
(** [get_alpha_subtree_unbiased prev prop] is [log (size prev) - log (size prop)].

    Consider a move that picks a node uniformly, i.e. with probability
    1 / (size of the whole tree), and regrows it from the prior. For that
    move, this quantity is the log Hastings correction when [prev] and
    [prop] are the complete trees before and after the move. *)
let get_alpha_subtree_unbiased (prev : kern) (prop : kern) : float =
  let log_size tree = Float.log (Float.of_int (size tree)) in
  log_size prev -. log_size prop
val get_alpha_subtree_unbiased : kern -> kern -> float = <fun>
(** State of the MCMC chain:
    - [cov_k]: current covariance-kernel expression.
    - [noise]: variance added to the covariance diagonal.
    - [xs], [ys]: the observed training inputs and outputs.
    - [log_likelihood]: cached value of
      [compute_log_likelihood cov_k noise xs ys]. *)
type trace = {cov_k: kern; noise: float; xs: T.t; ys: T.t; log_likelihood: float}
type trace = { cov_k : kern; noise : float; xs : T.t; ys : T.t; log_likelihood : float; }
(** One Metropolis-Hastings step on the kernel structure. A node is picked
    uniformly with [pick_random_node_unbiased], replaced by a fresh draw from
    [covariance_prior], and the proposal is accepted or rejected.

    The log acceptance ratio has two parts:
    - the log-likelihood difference between the new and old traces;
    - log |old tree| - log |new tree|.
    The forward move picks the node with probability 1/|old tree|; the
    reverse move picks the same position in the new tree with probability
    1/|new tree|. The prior terms of the swapped subtrees cancel against the
    proposal densities. *)
let mh_resample_subtree_unbiased (prev_trace : trace) : trace =
  let i_delta, _replaced_subtree = pick_random_node_unbiased prev_trace.cov_k 1 in
  let subtree = covariance_prior () in
  let cov_k_new = replace_subtree prev_trace.cov_k 1 ~to_:subtree ~on_:i_delta in
  let log_likelihood = compute_log_likelihood cov_k_new prev_trace.noise prev_trace.xs prev_trace.ys in
  let new_trace = {cov_k= cov_k_new; noise= prev_trace.noise; xs= prev_trace.xs; ys= prev_trace.ys; log_likelihood} in
  (* Node-selection probability depends on the size of the WHOLE tree, not of
     the replaced/proposed subtrees; those differ unless i_delta is the root. *)
  let alpha_size = get_alpha_subtree_unbiased prev_trace.cov_k cov_k_new in
  let alpha_ll = Float.(new_trace.log_likelihood - prev_trace.log_likelihood) in
  let alpha = Float.(alpha_ll + alpha_size) in
  if Float.(log (rand ()) < alpha) then new_trace else prev_trace
val mh_resample_subtree_unbiased : trace -> trace = <fun>
(** One independence Metropolis-Hastings step on [noise].

    The proposal is [gamma 1.0 1.0 + 0.01], the same distribution
    [initialize_trace] draws from. Because of that, the acceptance ratio
    reduces to the log-likelihood difference. *)
let mh_resample_noise (prev_trace : trace) : trace =
  let noise = Float.(gamma 1.0 1.0 + 0.01) in
  let cov_k = prev_trace.cov_k and xs = prev_trace.xs and ys = prev_trace.ys in
  let proposal = {cov_k; noise; xs; ys; log_likelihood= compute_log_likelihood cov_k noise xs ys} in
  let log_alpha = Float.(proposal.log_likelihood - prev_trace.log_likelihood) in
  if Float.(log (rand ()) < log_alpha) then proposal else prev_trace
val mh_resample_noise : trace -> trace = <fun>
(** [initialize_trace xs ys] starts a chain on data ([xs], [ys]).
    - The kernel is drawn from [covariance_prior].
    - The noise is drawn as [gamma 1.0 1.0 + 0.01].
    - The log-likelihood is cached in the trace.
    The kernel is drawn before the noise, in that RNG order. *)
let initialize_trace (xs : T.t) (ys : T.t) : trace =
  let cov_k = covariance_prior () in
  let noise = Float.(gamma 1.0 1.0 + 0.01) in
  {cov_k; noise; xs; ys; log_likelihood= compute_log_likelihood cov_k noise xs ys}
val initialize_trace : T.t -> T.t -> trace = <fun>
(** [run_mcmc trace iters] applies [iters] sweeps to [trace]. Each sweep is
    a structure step followed by a noise step. [iters <= 0] returns [trace]
    unchanged. *)
let run_mcmc (prev_trace : trace) (iters : int) : trace =
  let sweep tr = mh_resample_noise (mh_resample_subtree_unbiased tr) in
  let rec go tr remaining = if remaining <= 0 then tr else go (sweep tr) (remaining - 1) in
  go prev_trace iters
val run_mcmc : trace -> int -> trace = <fun>
(** [rescale_linear xs yl yh] maps the values of [xs] affinely so that the
    minimum of [xs] lands on [yl] and its maximum on [yh].

    @raise Invalid_argument if all values of [xs] are equal. The slope would
    divide by zero and silently produce non-finite values, which would then
    break the Cholesky factorisations downstream. *)
let rescale_linear (xs : T.t) (yl : float) (yh : float) : T.t =
  let xl = T.minimum xs |> T.float_value in
  let xh = T.maximum xs |> T.float_value in
  if Float.(xh = xl) then invalid_arg "rescale_linear: input has zero range" ;
  let slope = Float.((yh - yl) / (xh - xl)) in
  let intercept = Float.(yh - (xh * slope)) in
  T.(add_scalar (scale_f xs slope) (S.f intercept))
val rescale_linear : T.t -> float -> float -> T.t = <fun>
(** [load_dataset_from_path path n_test] reads the float columns ["x"] and
    ["y"] from the CSV at [path] and prepares a train/test split.
    - x is rescaled to [0, 1] and y to [-1, 1], using the min/max of the full
      dataset.
    - The last [n_test] rows are held out.
    - The result is [((xs_train, ys_train), (xs_test, ys_test))].

    @raise Invalid_argument if [n_test] is not in [0, number of rows]. *)
let load_dataset_from_path (path : string) (n_test : int) : (T.t * T.t) * (T.t * T.t) =
  let df = Dataframe.of_csv ~sep:',' ~types:[|"f"; "f"|] path in
  let column name = Dataframe.get_col_by_name df name |> Dataframe.unpack_float_series |> T.of_float1 in
  let xs = rescale_linear (column "x") 0.0 1.0 in
  let ys = rescale_linear (column "y") (-1.0) 1.0 in
  let n = T.shape1_exn xs in
  if n_test < 0 || n_test > n then
    invalid_arg (Printf.sprintf "load_dataset_from_path: n_test=%d outside [0, %d]" n_test n) ;
  (* Split into a leading training part and a trailing held-out part. The
     explicit match replaces the previously warning-suppressed partial pattern. *)
  let split t =
    match T.split_with_sizes ~split_sizes:[n - n_test; n_test] t with
    | [train; test] -> (train, test)
    | parts ->
        failwith
          (Printf.sprintf "load_dataset_from_path: expected 2 parts from split, got %d" (List.length parts))
  in
  let xs_train, xs_test = split xs in
  let ys_train, ys_test = split ys in
  ((xs_train, ys_train), (xs_test, ys_test))
val load_dataset_from_path : string -> int -> (T.t * T.t) * (T.t * T.t) = <fun>
(* Load the airline data, holding out the final 15 observations for evaluation. *)
let (xs_train, ys_train), (xs_test, ys_test) =
  let n_held_out = 15 in
  load_dataset_from_path "airline.csv" n_held_out
2022-11-30 09:32:13.887 WARN : Owl_io.head: ignored exception "Assert_failure src/base/misc/owl_io.ml:138:9"
val xs_train : T.t = <abstr> val ys_train : T.t = <abstr> val xs_test : T.t = <abstr> val ys_test : T.t = <abstr>
(** [get_conditional_mu_cov cov_k noise xs ys new_xs] conditions the
    zero-mean GP with kernel [cov_k] on observations [ys] at [xs]. It returns
    the mean vector and covariance matrix at [new_xs].

    How it works:
    - The joint covariance over the concatenation of [xs] and [new_xs] is
      built with [noise] added on its whole diagonal.
    - It is split into blocks K11 (observed), K22 (new), K12 and K21.
    - Gaussian conditioning gives
      mu = K21 K11^-1 ys and cov = K22 - K21 K11^-1 K12.
    - The covariance is returned symmetrised as (cov + cov^T) / 2, presumably
      to absorb round-off asymmetry before the callers' Cholesky
      factorisation. *)
let get_conditional_mu_cov (cov_k : kern) (noise : float) (xs : T.t) (ys : T.t) (new_xs : T.t) : T.t * T.t =
  let n_prev = T.shape1_exn xs in
  let n_new = T.shape1_exn new_xs in
  let joint = compute_cov_matrix_vectorized cov_k noise (T.cat [xs; new_xs]) |> T.to_float2_exn in
  (* Copy the [rows] x [cols] block of [joint] whose top-left corner is (row0, col0). *)
  let block ~row0 ~col0 ~rows ~cols =
    T.of_float2 (Array.init rows ~f:(fun i -> Array.init cols ~f:(fun j -> joint.(row0 + i).(col0 + j))))
  in
  let k11 = block ~row0:0 ~col0:0 ~rows:n_prev ~cols:n_prev in
  let k22 = block ~row0:n_prev ~col0:n_prev ~rows:n_new ~cols:n_new in
  let k12 = block ~row0:0 ~col0:n_prev ~rows:n_prev ~cols:n_new in
  let k21 = block ~row0:n_prev ~col0:0 ~rows:n_new ~cols:n_prev in
  let mu_prev = T.zeros [n_prev] in
  let mu_new = T.zeros [n_new] in
  let cond_mu = T.(mu_new + matmul k21 (linalg_solve k11 (ys - mu_prev))) in
  let cond_cov = T.(k22 - matmul k21 (linalg_solve k11 k12)) in
  let cond_cov_sym = T.(scale_f cond_cov 0.5 + scale_f (tr cond_cov) 0.5) in
  (cond_mu, cond_cov_sym)
val get_conditional_mu_cov : kern -> float -> T.t -> T.t -> T.t -> T.t * T.t = <fun>
(** [compute_log_likelihood_predictive cov_k noise xs ys new_xs new_ys] is
    the log-density of [new_ys] under the GP conditional at [new_xs] given
    ([xs], [ys]). *)
let compute_log_likelihood_predictive (cov_k : kern) (noise : float) (xs : T.t) (ys : T.t) (new_xs : T.t) (new_ys : T.t) : float =
  let cond_mu, cond_cov = get_conditional_mu_cov cov_k noise xs ys new_xs in
  let dist = D.multivariate_normal cond_mu (T.linalg_cholesky cond_cov) in
  T.float_value (dist#log_prob new_ys)
val compute_log_likelihood_predictive : kern -> float -> T.t -> T.t -> T.t -> T.t -> float = <fun>
(** [gp_predictive_samples cov_k noise xs ys new_xs n] draws [n] vectors
    from the GP conditional at [new_xs] given ([xs], [ys]). *)
let gp_predictive_samples (cov_k : kern) (noise : float) (xs : T.t) (ys : T.t) (new_xs : T.t) (n : int) : T.t list =
  let cond_mu, cond_cov = get_conditional_mu_cov cov_k noise xs ys new_xs in
  let dist = D.multivariate_normal cond_mu (T.linalg_cholesky cond_cov) in
  List.init n ~f:(fun _ -> dist#sample ())
val gp_predictive_samples : kern -> float -> T.t -> T.t -> T.t -> int -> T.t list = <fun>
(** Per-epoch summary built by [infer_and_predict]:
    - [log_likelihood]: the training-data log-likelihood cached in the trace
      at the end of the epoch.
    - [predictions_held_out]: predictive samples at the held-out inputs,
      drawn with [gp_predictive_samples]. *)
type result =
{ log_likelihood: float
; predictions_held_out: T.t list }
type result = { log_likelihood : float; predictions_held_out : T.t list; }
(** [infer_and_predict trace iters xs_train ys_train xs_test ys_test npred_out]
    works in two steps:
    - It advances the chain by [iters] MCMC sweeps.
    - It then draws [npred_out] predictive samples at [xs_test], conditioned
      on the training data.
    It returns the updated trace and the epoch's [result]. The held-out
    targets are currently unused; the parameter is kept for interface
    compatibility. *)
let infer_and_predict (trace : trace) (iters : int) (xs_train : T.t) (ys_train : T.t) (xs_test : T.t) (_ys_test : T.t) (npred_out : int) : trace * result =
  let updated = run_mcmc trace iters in
  let predictions_held_out =
    gp_predictive_samples updated.cov_k updated.noise xs_train ys_train xs_test npred_out
  in
  (updated, {log_likelihood= updated.log_likelihood; predictions_held_out})
val infer_and_predict : trace -> int -> T.t -> T.t -> T.t -> T.t -> int -> trace * result = <fun>
(** [run_pipeline xs_train ys_train xs_test ys_test iters epochs npred_held_out seed]
    runs the whole experiment:
    - Seeds torch's RNG with [seed].
    - Initialises a trace on the training data.
    - Runs [epochs] rounds of [infer_and_predict] with [iters] sweeps each,
      threading the trace from round to round.
    It returns one [result] per epoch, in order. *)
let run_pipeline (xs_train : T.t) (ys_train : T.t) (xs_test : T.t) (ys_test : T.t) (iters : int) (epochs : int) (npred_held_out : int) (seed : int) : result list =
  let iterations = List.init epochs ~f:(fun _ -> iters) in
  Torch_core.Wrapper.manual_seed seed ;
  let initial = initialize_trace xs_train ys_train in
  let _final_trace, rev_results =
    List.fold iterations ~init:(initial, []) ~f:(fun (tr, acc) n_iters ->
        let next_tr, res = infer_and_predict tr n_iters xs_train ys_train xs_test ys_test npred_held_out in
        (next_tr, res :: acc) )
  in
  List.rev rev_results
val run_pipeline : T.t -> T.t -> T.t -> T.t -> int -> int -> int -> int -> result list = <fun>
(* 200 epochs of 5 MCMC sweeps each, with 100 predictive samples per epoch and seed 42. *)
let stats =
  let iters_per_epoch = 5 and epochs = 200 and npred_held_out = 100 and seed = 42 in
  run_pipeline xs_train ys_train xs_test ys_test iters_per_epoch epochs npred_held_out seed
val stats : result list = [{log_likelihood = -0.32733154296875; predictions_held_out = [<abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; ...]}; ...]
open Matplotlib
(** [plot ()] renders the current Matplotlib figure as PNG, clears it, and
    displays the image (base64-encoded) in the Jupyter notebook. *)
let plot () =
  let png = Mpl.plot_data `png in
  Mpl.clf () ;
  Jupyter_notebook.display ~base64:true "image/png" png |> ignore

(* Use the Agg backend and the ggplot style for every figure in this notebook. *)
let () =
  Mpl.set_backend Agg ;
  Mpl.style_use "ggplot"
val plot : unit -> unit = <fun>
(* Summary of the final epoch. *)
let {log_likelihood= likelihood; predictions_held_out= predictions} = List.last_exn stats;;
val likelihood : float = 105.283782958984375 val predictions : T.t list = [<abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; <abstr>; ...]
(* Raw-data ranges of the x (year) and y (passenger) columns. [rescale_linear]
   sends each column's min/max to the ends of its target interval, so these
   bounds are what [unscale_linear] needs to map back to original units. *)
let x_min, x_max = (1.949041666666666742e+03, 1.960958333333333258e+03)
let y_min, y_max = (1.12e+02, 4.32e+02)
val x_min : float = 1949.04166666666674 val y_min : float = 112.
val x_max : float = 1960.95833333333326 val y_max : float = 432.
(** [unscale_linear xs yl yh xl xh] inverts the affine map that sends [xl]
    to [yl] and [xh] to [yh]. Each value of [xs], given on the [yl, yh]
    scale, is mapped back to the [xl, xh] scale. *)
let unscale_linear xs yl yh xl xh =
  let slope = (yh -. yl) /. (xh -. xl) in
  let intercept = yh -. (xh *. slope) in
  Array.map xs ~f:(fun y -> (y -. intercept) /. slope)
val unscale_linear : float Core.Array.t -> float -> float -> float -> float -> float Core.Array.t = <fun>
(* Convert tensors back to original units for plotting: x from [0, 1] to
   years and y from [-1, 1] to passenger counts. The predictions are y-valued. *)
let xs_train_actual, ys_train_actual, xs_test_actual, ys_test_actual, predictions =
  let years t = unscale_linear (T.to_float1_exn t) 0.0 1.0 x_min x_max in
  let passengers t = unscale_linear (T.to_float1_exn t) (-1.0) 1.0 y_min y_max in
  ( years xs_train
  , passengers ys_train
  , years xs_test
  , passengers ys_test
  , List.map predictions ~f:passengers )
val xs_train_actual : Base.float Core.Array.t = [|1949.04166666666674; 1949.1249999990689; 1949.20833333147084; 1949.29166666387277; 1949.37499999627494; 1949.45833332867687; 1949.54166666107881; 1949.62499999348097; 1949.7083333258829; 1949.79166670267796; 1949.874999990687; 1949.95833336748206; 1950.04166665549087; 1950.12500003228593; 1950.20833332029497; 1950.29166660830401; 1950.37499998509907; 1950.45833327310788; 1950.54166673868895; 1950.62499993791198; 1950.70833331470703; 1950.79166669150209; 1950.87500006829714; 1950.95833326752; 1951.041666644315; 1951.12499984353781; 1951.20833339790511; 1951.29166659712791; 1951.37499997392297; 1951.45833335071802; 1951.54166654994106; 1951.62499992673611; 1951.70833330353116; 1951.79166668032622; 1951.87499987954902; 1951.95833343391632; 1952.04166681071138; 1952.12500000993418; 1952.20833320915699; 1952.29166676352429; 1952.3749999627471; 1952.4583335171144; 1952.5416667163372; 1952.62499991556024; 1952.70833346992731; 1952.79166666915035; 1952.87499986837315; 1952.95833342274045; 1953.04166662196326; 1953.12500017633056; 1953.2083330204091; 1953.29166657477617; 1953.37500012914347; 1953.45833297322201; 1953.54166652758931; 1953.62500008195639; 1953.70833328117942; 1953.79166648040223; 1953.87500003476953; 1953.95833323399233; 1954.04166643321514; 1954.12499998758244; 1954.20833318680525; 1954.29166674117255; 1954.37499994039536; 1954.45833313961839; 1954.54166669398546; 1954.62500024835276; 1954.7083330924313; 1954.7916666467986; 1954.87500020116568; 1954.95833304524422; 1955.04166695475578; 1955.12500015397882; 1955.20833335320162; 1955.29166655242443; 1955.37499975164747; 1955.45833295087027; 1955.54166686038184; 1955.62500005960464; 1955.70833325882768; 1955.79166645805049; 1955.87500036756205; 1955.95833285649633; 1956.04166676600789; 1956.1249999652307; 1956.20833316445351; 1956.29166636367654; 1956.37500027318811; 1956.45833276212215; 1956.54166667163372; 1956.62499987085653; 1956.70833307007956; 
1956.79166626930237; 1956.87500017881393; 1956.95833337803674; 1957.04166657725978; 1957.12499977648258; 1957.20833368599415; 1957.29166688521696; 1957.37499937415123; 1957.4583332836628; 1957.5416664828856; 1957.62499968210864; 1957.7083335916202; 1957.79166679084301; 1957.87499927977728; 1957.95833389957738; 1958.04166638851166; 1958.12499958773446; 1958.20833349724603; 1958.29166669646906; 1958.37499989569187; 1958.45833380520344; 1958.54166629413771; 1958.62499949336052; 1958.70833340287209; 1958.79166660209489; 1958.87499980131793; 1958.95833371082949; 1959.04166619976354; 1959.12499939898657; 1959.20833330849814; 1959.29166650772095; 1959.37499970694375; 1959.45833361645532; 1959.54166681567835; 1959.6249993046124; 1959.70833321412397|]
val ys_train_actual : Base.float Core.Array.t = [|116.942083358764677; 120.648653030395536; 129.29729652404788; 127.444021224975614; 122.501928329467802; 131.15058135986331; 139.181472778320341; 139.181472778320341; 131.768342971801786; 121.266414642334013; 112.000000000000028; 120.648653030395536; 118.795368194580107; 125.590736389160185; 134.857141494751; 131.15058135986331; 124.972974777221708; 139.799224853515653; 152.772209167480497; 152.772209167480497; 145.359069824218778; 129.915058135986357; 118.17760658264163; 134.239379882812528; 137.328187942504911; 140.41698646545413; 157.714282989501982; 148.447877883911161; 154.007722854614286; 157.714282989501982; 170.687257766723661; 170.687257766723661; 161.420852661132841; 147.830116271972685; 137.945949554443388; 150.301162719726591; 153.38996124267581; 158.949806213378935; 166.980697631835966; 159.567567825317411; 160.803091049194364; 182.424709320068388; 189.837839126586942; 197.250968933105497; 176.864864349365263; 165.745174407959; 154.007722854614286; 167.598459243774442; 168.833972930908232; 168.833972930908232; 193.54440402984622; 192.926647186279325; 189.220077514648466; 197.868730545043974; 210.841700553894071; 215.783788681030302; 194.162165641784696; 178.100387573242216; 158.949806213378935; 171.922780990600614; 173.776065826416044; 163.891899108886747; 192.926647186279325; 187.984554290771513; 192.308885574340849; 210.841700553894071; 234.316603660583525; 228.7567586898804; 207.752897262573271; 189.220077514648466; 173.158304214477567; 189.220077514648466; 197.250968933105497; 191.691123962402372; 212.694985389709501; 213.930503845214872; 214.548265457153349; 242.347492694854765; 272.617761611938477; 262.115828514099121; 240.494210243225126; 217.019307136535673; 194.162165641784696; 219.490348815918; 223.196913719177275; 218.872587203979521; 243.583013534545927; 241.111971855163603; 244.200775146484403; 278.795368194580078; 302.888032913208; 297.945944786071777; 267.057916641235352; 
236.787645339965849; 215.166027069091825; 236.787645339965849; 242.347492694854765; 233.698844432830839; 267.675673484802246; 262.733590126037598; 267.057916641235352; 308.447877883911133; 335.011579513549805; 336.247102737426758; 297.328183174133301; 262.115828514099121; 236.169886112213163; 255.320465087890653; 257.791506767272949; 244.200775146484403; 271.382238388061523; 262.733590126037598; 272.; 316.478759765625; 351.073352813720703; 359.722005844116211; 297.328183174133301; 269.528958320617676; 239.258689403533964; 255.93822669982913; 270.14671516418457; 259.027029991149902; 298.563706398010254; 292.386099815368652; 307.21235466003418; 339.335901260376; 386.285707473754883; 393.081075668334961; 333.776056289672852|]
val xs_test_actual : Base.float Core.Array.t = [|1959.791666413347; 1959.87499961256981; 1959.95833352208138; 1960.04166672130418; 1960.12499921023846; 1960.20833383003878; 1960.29166631897283; 1960.37499951819586; 1960.45833342770743; 1960.54166662693024; 1960.62499982615304; 1960.70833373566461; 1960.79166622459888; 1960.87499942382169; 1960.95833333333326|]
val ys_test_actual : Base.float Core.Array.t = [|299.18146800994873; 271.382238388061523; 297.945944786071777; 305.35906982421875; 289.297296524047852; 306.594593048095703; 332.540542602539062; 339.335901260376; 378.254825592041; 431.999990463256836; 422.115823745727539; 361.575290679931641; 332.540542602539062; 288.679534912109375; 314.625484466552734|]
val predictions : Base.float Core.Array.t list = [[|330.707685470581055; 296.305815696716309; 289.742047309875488; 269.697917222976685; 288.021586656570435; 305.400754928588867; 333.346654891967773; 365.653602600097656; 352.545320510864258; 374.761440277099609; 365.450279235839844; 325.374843597412109; 284.445456981658936; 232.76416492462161; 254.147569656372099|]; [|305.56135368347168; 310.397994041442871; 271.280216693878174; 271.416354894638062; 243.90143585205081; 292.093305110931396; 331.607682228088379; 340.525538444519043; 392.047702789306641; 413.60710334777832; 343.402454376220703; 320.700194358825684; 288.140811443328857; 298.792290210723877; 315.322978019714355|]; [|311.504961967468262; 256.744551181793213; 258.004705429077148; 272.146616697311401; 253.562920808792143; 308.243376731872559; 339.759184837341309; 382.993051528930664; 425.648414611816406; 404.158651351928711; 359.362737655639648; 298.415703296661377; 287.539165735244751; 260.542180061340332; 251.64776039123538|]; [|316.569625854492188; 306.975333213806152; 289.412371635437; 296.205336570739746; 299.213249206542969; 264.367714643478394; 330.784551620483398; 368.536989212036133; 372.405254364013672; 449.065887451171875; 406.155559539794922; 349.764158248901367; 286.690146446228; 268.889368295669556; 314.255387306213379|]; [|291.383325576782227; 280.058149814605713; 276.112815260887146; 258.113467216491699; 276.431219100952148; 268.97474479675293; 267.851068496704102; 353.927709579467773; 356.940290451049805; 403.720380783081055; 374.169103622436523; 341.738578796386719; 239.242207527160673; 236.1877889633179; 264.331582307815552|]; [|251.824357032775907; 246.976282119751; 265.330719470977783; 215.617124557495146; 216.09635543823245; 237.677475929260282; 255.283967971801786; 301.390163421630859; 355.059425354003906; 357.357170104980469; 327.086846351623535; 286.805221557617188; 252.471110343933134; 265.179692029953; 232.036540031433134|]; [|300.057460784912109; 313.118640899658203; 
295.407411575317383; 236.287285804748564; 280.817483186721802; 271.616135358810425; 327.780119895935059; 304.634372711181641; 353.432104110717773; 397.650014877319336; 354.908363342285156; 344.106246948242188; 341.013404846191406; 279.216455936431885; 280.242983818054199|]; [|314.986373901367188; 273.83110237121582; 301.001011848449707; 247.1067924499512; 236.972188472747831; 281.482946395874; 330.980231285095215; 358.717042922973633; 382.003871917724609; 391.6759033203125; 345.084416389465332; 311.646334648132324; 284.724487781524658; 274.896474003791809; 269.188971042633057|]; [|282.811762809753418; 299.479901313781738; 299.550399303436279; 285.499090671539307; 271.299948811531067; 297.21726131439209; 302.281913280487061; 317.203943252563477; 358.03973388671875; 364.854824066162109; 350.330087661743164; 323.56013011932373; 268.48432731628418; 282.003904104232788; 265.654129505157471|]; [|289.388193607330322; 302.230047702789307; 299.07942008972168; 277.325747132301331; 283.688425540924072; 334.436447143554688; 322.680789947509766; 389.011528015136719; 417.827655792236328; 379.5653076171875; 387.690689086914062; 319.211103439331055; 301.773712158203125; 313.35810375213623; 314.166461944580078|]; [|327.811681747436523; 310.869640827178955; 300.668398857116699; 291.245830774307251; 271.239329099655151; 287.894865989685059; 302.464890003204346; 385.771171569824219; 424.379446029663086; 431.111652374267578; 399.149753570556641; 327.603251457214355; 288.835534572601318; 293.560354232788086; 243.232724189758329|]; [|311.109642505645752; 286.688019752502441; 320.729486465454102; 239.373137474060087; 270.931279182434082; 302.865285396575928; 342.611162185668945; 392.905208587646484; 383.18072509765625; 360.137502670288086; 338.655392646789551; 266.257898330688477; 277.587176084518433; 236.72835063934329; 235.435425758361845|]; [|283.232359409332275; 295.798096179962158; 256.386773109436035; 305.138818740844727; 303.815061569213867; 309.752096652984619; 339.99015998840332; 
379.773580551147461; 362.775251388549805; 379.202529907226562; 341.031105041503906; 305.207225799560547; 225.253965377807646; 258.494677066803; 260.132315874099731|]; [|328.425619125366211; 288.945033073425293; 286.305486679077148; 285.867502212524414; 307.104770660400391; 304.589497566223145; 323.011648178100586; 356.541692733764648; 441.455013275146484; 376.678211212158203; 368.185512542724609; 328.669301986694336; 287.281509160995483; 302.689280033111572; 337.273323059082031|]; [|305.708758354187; 305.353257179260254; 310.519530296325684; 280.788464665412903; 311.73677396774292; 322.226292610168457; 360.869590759277344; 367.05497932434082; 420.62858772277832; 451.05792236328125; 400.608131408691406; 355.276233673095703; 321.094934463501; 293.699397563934326; 304.913944721221924|]; [|361.329433441162109; 328.553211212158203; 297.168945789337158; 292.714173316955566; 266.588826894760132; 286.875962734222412; 303.65482759475708; 337.194797515869141; 398.243124008178711; 408.458187103271484; 407.662784576416; 385.289632797241211; 304.29623556137085; 273.913884282112122; 307.929305553436279|]; [|308.097619533538818; 277.473988056182861; 292.403227806091309; 245.433790206909208; 244.465643882751493; 246.471123218536405; 286.714024066925049; 357.822563171386719; ...|]; ...]
(* Final figure, on the original (unscaled) axes: the observed training series,
   the predictive sample paths over the held-out range, and the held-out points. *)
let fig, ax = Fig.create_with_ax () in
(* Training data: dashed black line with point markers. *)
Ax.plot ax ~label:"Observed Data" ~color:Black ~linestyle:(Other "--") ~marker:'.' ~xs:xs_train_actual ys_train_actual;
(* One unlabelled green line per predictive sample, drawn with alpha 0.025 so many samples can overlap. *)
List.iter predictions ~f:(fun p -> Ax.plot ax ~color:Green ~linestyle:(Other "-") ~alpha:0.025 ~xs:xs_test_actual p);
(* Empty series whose only purpose is a single legend entry for the sample lines above. *)
Ax.plot ax ~label:"Predictions" ~color:Green [||];
(* Held-out observations as red plus markers. *)
Ax.scatter ax ~label:"Held-out Data" ~marker:'+' ~s:30.0 ~c:Red (Array.zip_exn xs_test_actual ys_test_actual);
Ax.legend ax ~loc:UpperLeft ~framealpha:0.0 ();
Ax.set_xlabel ax ~fontsize:12.0 "Year";
Ax.set_ylabel ax ~fontsize:12.0 "Passenger Volume";
(* Leave a margin left of the first observation. *)
Ax.set_xlim ax ~left:1947.5 ();
Ax.set_xticks ax ~labels:[|"1948"; "1950"; "1952"; "1954"; "1956"; "1958"; "1960"|] [|1948.;1950.;1952.;1954.;1956.;1958.;1960.|];
Fig.set_tight_layout fig true;
Fig.set_size_inches fig ~w:4.0 ~h:3.0;
(* Render to PNG and display in the notebook. *)
plot ()
- : unit = ()