nrposner

Optimizing Astronomy Utilities in Rust, Continued

A week ago, I wrote a quick blog post on some attempts at using Rust to optimize the C-extension functions in the basil-core codebase that calculate orbital decay, and found that a quite straightforward, one-to-one conversion of the existing C code to Rust with PyO3 resulted in code that was simpler, more readable, and in many cases faster.

In the past week, Vera and I have been working more closely, and I've applied the same treatment to all the C array functions in the basil-core/astro directory. See the results here, or in the following GINORMOUS ASCII DATAFRAME.

=== Benchmark Results ===
┌─────────────────────────────────┬────────────┬────────────────────┬───────────────────────┬──────────────────┬─────────────────────┬────────────────┬───────────────────────┐
│ function_name                   ┆ array_size ┆ c_runtime_ns_total ┆ rust_runtime_ns_total ┆ c_ns_per_element ┆ rust_ns_per_element ┆ speedup_factor ┆ passed_equality_check │
│ ---                             ┆ ---        ┆ ---                ┆ ---                   ┆ ---              ┆ ---                 ┆ ---            ┆ ---                   │
│ str                             ┆ i64        ┆ i64                ┆ i64                   ┆ f64              ┆ f64                 ┆ f64            ┆ bool                  │
╞═════════════════════════════════╪════════════╪════════════════════╪═══════════════════════╪══════════════════╪═════════════════════╪════════════════╪═══════════════════════╡
│ beta_arr                        ┆ 10000      ┆ 46275767           ┆ 24130873              ┆ 0.462758         ┆ 0.241309            ┆ 1.9177         ┆ true                  │
│ orb_sep_evol_circ_arr           ┆ 10000      ┆ 116488159          ┆ 58884571              ┆ 1.164882         ┆ 0.588846            ┆ 1.978246       ┆ true                  │
│ peters_ecc_const_arr            ┆ 10000      ┆ 1301626275         ┆ 1270440543            ┆ 13.016263        ┆ 12.704405           ┆ 1.024547       ┆ true                  │
│ peters_ecc_integrand_arr        ┆ 10000      ┆ 1894272007         ┆ 1910899683            ┆ 18.94272         ┆ 19.108997           ┆ 0.991299       ┆ true                  │
│ merge_time_circ_arr             ┆ 10000      ┆ 73751355           ┆ 37069788              ┆ 0.737514         ┆ 0.370698            ┆ 1.989527       ┆ true                  │
│ orb_sep_evol_ecc_integrand_arr  ┆ 10000      ┆ 2326365439         ┆ 1917944653            ┆ 23.263654        ┆ 19.179447           ┆ 1.212947       ┆ true                  │
│ orbital_period_of_m1_m2_a_arr   ┆ 10000      ┆ 94656531           ┆ 47753358              ┆ 0.946565         ┆ 0.477534            ┆ 1.982196       ┆ true                  │
│ dwd_r_arr                       ┆ 10000      ┆ 551290596          ┆ 516216091             ┆ 5.512906         ┆ 5.162161            ┆ 1.067945       ┆ true                  │
│ dwd_a_arr                       ┆ 10000      ┆ 971168924          ┆ 928001166             ┆ 9.711689         ┆ 9.280012            ┆ 1.046517       ┆ true                  │
│ dwd_p_arr                       ┆ 10000      ┆ 1129603683         ┆ 1036108061            ┆ 11.296037        ┆ 10.361081           ┆ 1.090237       ┆ true                  │
│ mc_of_m1_m2                     ┆ 10000      ┆ 1357332236         ┆ 1349078009            ┆ 13.573322        ┆ 13.49078            ┆ 1.006118       ┆ true                  │
│ eta_of_m1_m2                    ┆ 10000      ┆ 47549433           ┆ 26281491              ┆ 0.475494         ┆ 0.262815            ┆ 1.809237       ┆ true                  │
│ mc_and_eta_of_m1_m2             ┆ 10000      ┆ 1404119967         ┆ 1355773537            ┆ 14.0412          ┆ 13.557735           ┆ 1.03566        ┆ true                  │
│ m_of_mc_eta                     ┆ 10000      ┆ 731157034          ┆ 714706382             ┆ 7.31157          ┆ 7.147064            ┆ 1.023017       ┆ true                  │
│ m1_of_m_eta                     ┆ 10000      ┆ 64740428           ┆ 31114962              ┆ 0.647404         ┆ 0.31115             ┆ 2.080685       ┆ true                  │
│ m2_of_m_eta                     ┆ 10000      ┆ 64335915           ┆ 31086978              ┆ 0.643359         ┆ 0.31087             ┆ 2.069545       ┆ true                  │
│ source_of_detector              ┆ 10000      ┆ 40128554           ┆ 24012530              ┆ 0.401286         ┆ 0.2401253           ┆ 1.671151       ┆ true                  │
│ detector_of_source              ┆ 10000      ┆ 40180218           ┆ 24031362              ┆ 0.401802         ┆ 0.240314            ┆ 1.671991       ┆ true                  │
│ lambda_tilde_of_eta_lambda1_la… ┆ 10000      ┆ 109649498          ┆ 71451001              ┆ 1.096495         ┆ 0.71451             ┆ 1.534611       ┆ true                  │
│ delta_lambda_of_eta_lambda1_la… ┆ 10000      ┆ 172836582          ┆ 92303447              ┆ 1.728366         ┆ 0.923034            ┆ 1.872482       ┆ true                  │
│ lambda1_of_eta_lambda_tilde_de… ┆ 10000      ┆ 261128634          ┆ 153712614             ┆ 2.611286         ┆ 1.537126            ┆ 1.698811       ┆ false                 │
│ lambda2_of_eta_lambda_tilde_de… ┆ 10000      ┆ 263326639          ┆ 155489233             ┆ 2.633266         ┆ 1.554892            ┆ 1.693536       ┆ false                 │
│ chieff_of_m1_m2_chi1z_chi2z     ┆ 10000      ┆ 59579577           ┆ 42373168              ┆ 0.595796         ┆ 0.423732            ┆ 1.406069       ┆ true                  │
│ chiMinus_of_m1_m2_chi1z_chi2z   ┆ 10000      ┆ 66508631           ┆ 42341161              ┆ 0.665086         ┆ 0.423412            ┆ 1.57078        ┆ true                  │
│ chi1z_of_m1_m2_chieff_chiMinus  ┆ 10000      ┆ 68969821           ┆ 42357896              ┆ 0.689698         ┆ 0.423579            ┆ 1.628264       ┆ true                  │
│ chi2z_of_m1_m2_chieff_chiMinus  ┆ 10000      ┆ 67015676           ┆ 42345472              ┆ 0.670157         ┆ 0.423455            ┆ 1.582594       ┆ true                  │
└─────────────────────────────────┴────────────┴────────────────────┴───────────────────────┴──────────────────┴─────────────────────┴────────────────┴───────────────────────┘

Each of these functions was run 10,000 times on arrays of length 10,000 with randomly initialized inputs, and the Rust results were compared against the C results for equality to within 4 decimal places (the passed_equality_check column).
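The per-element and speedup columns follow directly from the totals: each total covers 10,000 calls on 10,000-element arrays, i.e. 10^8 element evaluations, and the speedup factor is the C total divided by the Rust total. A minimal sketch of that arithmetic using the beta_arr row (the helper names are my own, not from the benchmark script):

```rust
/// Average cost of one element evaluation, in nanoseconds, assuming
/// `total_ns` covers `runs` calls that each process `len` elements.
fn ns_per_element(total_ns: u64, runs: u64, len: u64) -> f64 {
    total_ns as f64 / (runs * len) as f64
}

/// C time over Rust time: values above 1.0 mean the Rust version is faster.
fn speedup_factor(c_total_ns: u64, rust_total_ns: u64) -> f64 {
    c_total_ns as f64 / rust_total_ns as f64
}

fn main() {
    // beta_arr totals from the benchmark table
    let (c_total, rust_total) = (46_275_767_u64, 24_130_873_u64);
    let (runs, len) = (10_000, 10_000);
    println!("C:       {:.6} ns/element", ns_per_element(c_total, runs, len));
    println!("Rust:    {:.6} ns/element", ns_per_element(rust_total, runs, len));
    println!("speedup: {:.4}", speedup_factor(c_total, rust_total));
}
```

This reproduces the table's 0.462758, 0.241309 and 1.9177 for that row.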

Broadly: of the 26 C array functions we converted (all of which are basically identical in implementation, just reading slices of doubles from various numpy arrays and iterating over them to output a single new numpy array of doubles), we have:

  • One function shows a small slowdown in Rust compared to C: peters_ecc_integrand_arr, which is consistently about 1% slower.
  • Seven functions show a small (<10%) speedup in Rust compared to C, generally in the range of 2-8%.
  • 18 functions show a large (>10%) speedup in Rust compared to C, generally in the range of 40-80%.
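As a sanity check, the three buckets can be recomputed from the speedup_factor column. A small sketch, assuming the thresholds are exactly the ones stated (below 1.0, 1.0 up to 1.1, and 1.1 or more); Bucket, bucket and count are illustrative names:

```rust
/// Illustrative buckets matching the three categories in the list.
#[derive(Debug, PartialEq)]
enum Bucket {
    Slowdown,
    Small,
    Large,
}

fn bucket(speedup: f64) -> Bucket {
    if speedup < 1.0 {
        Bucket::Slowdown
    } else if speedup < 1.10 {
        Bucket::Small
    } else {
        Bucket::Large
    }
}

/// speedup_factor column of the benchmark table, in row order.
const SPEEDUPS: [f64; 26] = [
    1.9177, 1.978246, 1.024547, 0.991299, 1.989527, 1.212947, 1.982196,
    1.067945, 1.046517, 1.090237, 1.006118, 1.809237, 1.03566, 1.023017,
    2.080685, 2.069545, 1.671151, 1.671991, 1.534611, 1.872482, 1.698811,
    1.693536, 1.406069, 1.57078, 1.628264, 1.582594,
];

/// Number of table rows that fall into `target`.
fn count(target: Bucket) -> usize {
    SPEEDUPS.iter().filter(|&&s| bucket(s) == target).count()
}

fn main() {
    println!(
        "slowdown: {}, small: {}, large: {}",
        count(Bucket::Slowdown),
        count(Bucket::Small),
        count(Bucket::Large)
    );
}
```

Run over all 26 rows, it counts 1, 7 and 18, matching the list.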

This is a pretty interesting bimodal distribution. For the first group of functions, our new implementation has a small effect, if any. I haven't looked at the assembly yet, but I'd bet the generated instructions are pretty similar, with Rust perhaps dropping some bounds checks or making some numerical optimization that GCC at -O3 isn't. I'd be quite interested to see where the Rust slowdown for peters_ecc_integrand_arr comes from.
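On the bounds-check point: in Rust, an indexed loop checks each slice access unless LLVM can prove the index is in range, while iterating with zip never produces an index that needs checking. A sketch of both styles, using eta_of_m1_m2 as the example and assuming the standard symmetric mass ratio eta = m1*m2 / (m1 + m2)^2; this is not the basil-core source, and whether the checks actually disappear in the indexed version is exactly what the assembly would tell you.

```rust
/// Indexed loop, close to a line-by-line port of a C loop.
/// Each m1[i] / m2[i] / out[i] access is bounds-checked unless LLVM can prove it in range.
fn eta_indexed(m1: &[f64], m2: &[f64]) -> Vec<f64> {
    assert_eq!(m1.len(), m2.len());
    let mut out = vec![0.0; m1.len()];
    for i in 0..m1.len() {
        let m = m1[i] + m2[i];
        out[i] = m1[i] * m2[i] / (m * m);
    }
    out
}

/// Iterator version: zip walks both slices in lockstep, so there is no index to check.
fn eta_zipped(m1: &[f64], m2: &[f64]) -> Vec<f64> {
    assert_eq!(m1.len(), m2.len());
    m1.iter()
        .zip(m2)
        .map(|(&a, &b)| {
            let m = a + b;
            a * b / (m * m)
        })
        .collect()
}

fn main() {
    let m1 = [1.4, 10.0, 30.0];
    let m2 = [1.4, 5.0, 30.0];
    // Same arithmetic in the same order, so the outputs are bit-for-bit identical.
    assert_eq!(eta_indexed(&m1, &m2), eta_zipped(&m1, &m2));
    println!("{:?}", eta_zipped(&m1, &m2));
}
```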

The functions in this category tend to be those making very simple calculations, particularly those that make less use of integer exponentiation (the original C code uses pow(), which takes a double exponent, whereas the Rust code uses .powi(), which takes an integer exponent, where possible), as well as those iterating over a single array.
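For reference, the closest Rust analogue of C's pow(x, 3.0) is f64::powf, which takes a floating-point exponent; f64::powi takes an i32 exponent, and the standard library documents it as generally faster than powf, with a possibly different sequence of rounding operations. A minimal comparison (the cube helpers are just for illustration):

```rust
/// General floating-point power, like C's pow(x, 3.0).
fn cube_powf(x: f64) -> f64 {
    x.powf(3.0)
}

/// Integer-exponent power.
fn cube_powi(x: f64) -> f64 {
    x.powi(3)
}

fn main() {
    for x in [0.5_f64, 1.234_567, 42.0] {
        let (a, b) = (cube_powf(x), cube_powi(x));
        // The two may differ in the last bit or so, since rounding can differ.
        let rel = ((a - b) / a).abs();
        println!("x = {x}: powf = {a}, powi = {b}, relative difference = {rel:e}");
    }
}
```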

Then there are the functions that show a distinct speedup, generally 40% or more (orb_sep_evol_ecc_integrand_arr, at about 21%, is the outlier). These, too, fall into two categories, though the distinction isn't visible in the dataframe above: when these benchmarks are run at varying array sizes, some exhibit consistent speedups, like orb_sep_evol_ecc_integrand_arr from the last blog post, while others show a less stable, but still reliably large, speedup that tends to grow with array size.

These are generally functions making heavy use of integer exponentiation, or iterating over three or four arrays simultaneously.
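Iterating over several arrays at once is where the Rust version reads most differently from C code that indexes or advances several pointers in step: the slices are simply zipped together. A sketch of a four-input kernel in the style of chieff_of_m1_m2_chi1z_chi2z, assuming the standard effective-spin definition chi_eff = (m1*chi1z + m2*chi2z) / (m1 + m2); the function name and body are hypothetical, not the basil-core source.

```rust
/// Hypothetical four-input kernel in the style of chieff_of_m1_m2_chi1z_chi2z,
/// assuming the standard definition chi_eff = (m1*chi1z + m2*chi2z) / (m1 + m2).
fn chieff_arr(m1: &[f64], m2: &[f64], chi1z: &[f64], chi2z: &[f64]) -> Vec<f64> {
    let n = m1.len();
    assert!(m2.len() == n && chi1z.len() == n && chi2z.len() == n);
    // Nested zips walk all four slices in lockstep; no indices, no pointer arithmetic.
    m1.iter()
        .zip(m2)
        .zip(chi1z)
        .zip(chi2z)
        .map(|(((&a, &b), &s1), &s2)| (a * s1 + b * s2) / (a + b))
        .collect()
}

fn main() {
    let m1 = [10.0, 20.0];
    let m2 = [10.0, 5.0];
    let chi1z = [0.5, 0.0];
    let chi2z = [0.5, 1.0];
    println!("{:?}", chieff_arr(&m1, &m2, &chi1z, &chi2z));
}
```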

The biggest win besides performance is plain readability: I invite you to read through the implementations in the two codebases and note that the Rust code does the same work without manually invoking a ton of pointer magic. This was also the observation in Cliff Biffle's 'Learn Rust the Dangerous Way'.

Meanwhile, Vera managed to revamp the C version of orb_sep_evol_ecc_integrand_arr to make it dramatically faster, going from around 60ns/element to 36ns/element, now just a smidge slower than Rust's 30ns/element. Interestingly, applying the same optimization to the Rust version made it run a hair SLOWER than C, at around 37ns/element.

That said, it's not all sunshine and roses: two of these Rust functions (the two that fail the equality check in the table) exhibit numerical instability on the input range I was using. I checked the implementations, and I'm pretty sure the C code has the same issues, but because LLVM and GCC handle that instability differently, many of the results differ by a lot more than one part in a thousand.
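"4 decimal places" and "one part in a thousand" are different kinds of tolerance: the first is absolute, the second relative, and they can disagree badly when outputs span several orders of magnitude. A minimal sketch of both checks plus a mismatch fraction; the names and thresholds are illustrative assumptions, not necessarily the check behind the passed_equality_check column.

```rust
/// Absolute check, e.g. tol = 1e-4 for "agree to 4 decimal places".
fn close_abs(a: f64, b: f64, tol: f64) -> bool {
    (a - b).abs() <= tol
}

/// Relative check, e.g. rtol = 1e-3 for "within one part in a thousand".
fn close_rel(a: f64, b: f64, rtol: f64) -> bool {
    (a - b).abs() <= rtol * a.abs().max(b.abs())
}

/// Fraction of element pairs whose C and Rust outputs fail the relative check.
fn fraction_mismatched(c_out: &[f64], rust_out: &[f64], rtol: f64) -> f64 {
    assert_eq!(c_out.len(), rust_out.len());
    let bad = c_out
        .iter()
        .zip(rust_out)
        .filter(|&(&a, &b)| !close_rel(a, b, rtol))
        .count();
    bad as f64 / c_out.len() as f64
}

fn main() {
    // Large outputs: 4 decimal places is very strict, one part in a thousand is not.
    println!("{}", close_abs(12345.6, 12345.7, 1e-4));
    println!("{}", close_rel(12345.6, 12345.7, 1e-3));
    // Tiny outputs: 4 decimal places passes values that differ by a factor of two.
    println!("{}", close_abs(2e-6, 4e-6, 1e-4));
    println!("{}", close_rel(2e-6, 4e-6, 1e-3));
    println!("{}", fraction_mismatched(&[1.0, 2.0, 3.0, 4.0], &[1.0, 2.0, 3.5, 4.0], 1e-3));
}
```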

Ultimately, that leaves us with sixteen strong candidates for Rust conversion (the eighteen large speedups minus the two unstable functions) and seven weak candidates.

The next step, which Vera is interested in, would be to integrate Rust directly into basil-core, which builds with meson-python. However, we haven't found a clear example of a codebase that already does this, and from talking to the folks over at meson, it seems like the kind of thing that should be possible in the future, but it's going to take some core maintainers (read: actual experts) doing a bunch of work to make it stable. So for the time being, it's on the shelf.

Ah, I think my model has finally finished training. See you all next week.