Skip to content

Commit

Permalink
Fixes for pool-allocated view sizes
Browse files (browse the repository at this point in the history)
  • Loading branch information
jgfouca committed Jan 13, 2025
1 parent 9d6d25b commit 3751726
Showing 1 changed file with 54 additions and 40 deletions.
94 changes: 54 additions & 40 deletions components/eamxx/src/physics/rrtmgp/scream_rrtmgp_interface.hpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -409,45 +409,51 @@ static void rrtmgp_main(
const int int_size2 = sw_nband;
const int int_size3 = 2*lw_nband;
const int int_size4 = lw_nband;
const int int_size5 = sw_ngpt;
const int int_size6 = lw_ngpt;

const int real_size1 = ncol*nlay*sw_nband;
const int real_size2 = ncol*nlay*lw_nband;
const int real_size3 = ncol*nlay*sw_ngpt;
const int real_size4 = ncol*nlay*lw_ngpt;

const int total_int_size = 3 * (int_size1 + int_size2 + int_size3 + int_size4);
const int total_real_size = 3 * (3 * real_size1 + real_size2);
const int total_int_size = 2 * (int_size1 + int_size2 + int_size3 + int_size4) + (int_size1 + int_size5 + int_size3 + int_size6);
const int total_real_size = 2 * (3 * real_size1 + real_size2) + (3*real_size3 + real_size4);
auto int_data = pool_t::template alloc_and_init<int>(total_int_size); int *dcurr_int = int_data.data();

view_t<int**> sw_band2gpt_mem(dcurr_int, 2, sw_nband); dcurr_int += int_size1;
view_t<int*> sw_gpt2band_mem(dcurr_int, sw_nband); dcurr_int += int_size2;
view_t<int*> sw_gpt2band_mem(dcurr_int, sw_nband); dcurr_int += int_size2;
view_t<int**> lw_band2gpt_mem(dcurr_int, 2, lw_nband); dcurr_int += int_size3;
view_t<int*> lw_gpt2band_mem(dcurr_int, lw_nband); dcurr_int += int_size4;
view_t<int*> lw_gpt2band_mem(dcurr_int, lw_nband); dcurr_int += int_size4;

view_t<int**> sw_cloud_band2gpt_mem(dcurr_int, 2, sw_nband); dcurr_int += int_size1;
view_t<int*> sw_cloud_gpt2band_mem(dcurr_int, sw_nband); dcurr_int += int_size2;
view_t<int*> sw_cloud_gpt2band_mem(dcurr_int, sw_nband); dcurr_int += int_size2;
view_t<int**> lw_cloud_band2gpt_mem(dcurr_int, 2, lw_nband); dcurr_int += int_size3;
view_t<int*> lw_cloud_gpt2band_mem(dcurr_int, lw_nband); dcurr_int += int_size4;
view_t<int*> lw_cloud_gpt2band_mem(dcurr_int, lw_nband); dcurr_int += int_size4;

view_t<int**> sw_subcloud_band2gpt_mem(dcurr_int, 2, sw_nband); dcurr_int += int_size1;
view_t<int*> sw_subcloud_gpt2band_mem(dcurr_int, sw_ngpt); dcurr_int += int_size2;
view_t<int*> sw_subcloud_gpt2band_mem(dcurr_int, sw_ngpt); dcurr_int += int_size5;
view_t<int**> lw_subcloud_band2gpt_mem(dcurr_int, 2, lw_nband); dcurr_int += int_size3;
view_t<int*> lw_subcloud_gpt2band_mem(dcurr_int, lw_ngpt); dcurr_int += int_size4;
view_t<int*> lw_subcloud_gpt2band_mem(dcurr_int, lw_ngpt); dcurr_int += int_size6;
assert(dcurr_int - int_data.data() == total_int_size);

auto data = pool_t::template alloc<RealT>(total_real_size); RealT *dcurr = data.data();
auto data = pool_t::template alloc_and_init<RealT>(total_real_size); RealT *dcurr = data.data();

view_t<RealT***> sw_tau_mem(dcurr, ncol, nlay, sw_nband); dcurr += real_size1;
view_t<RealT***> sw_ssa_mem(dcurr, ncol, nlay, sw_nband); dcurr += real_size1;
view_t<RealT***> sw_g_mem(dcurr, ncol, nlay, sw_nband); dcurr += real_size1;
view_t<RealT***> sw_g_mem (dcurr, ncol, nlay, sw_nband); dcurr += real_size1;
view_t<RealT***> lw_tau_mem(dcurr, ncol, nlay, lw_nband); dcurr += real_size2;

view_t<RealT***> sw_cloud_tau_mem(dcurr, ncol, nlay, sw_nband); dcurr += real_size1;
view_t<RealT***> sw_cloud_ssa_mem(dcurr, ncol, nlay, sw_nband); dcurr += real_size1;
view_t<RealT***> sw_cloud_g_mem(dcurr, ncol, nlay, sw_nband); dcurr += real_size1;
view_t<RealT***> sw_cloud_g_mem (dcurr, ncol, nlay, sw_nband); dcurr += real_size1;
view_t<RealT***> lw_cloud_tau_mem(dcurr, ncol, nlay, lw_nband); dcurr += real_size2;

view_t<RealT***> sw_subcloud_tau_mem(dcurr, ncol, nlay, sw_ngpt); dcurr += real_size1;
view_t<RealT***> sw_subcloud_ssa_mem(dcurr, ncol, nlay, sw_ngpt); dcurr += real_size1;
view_t<RealT***> sw_subcloud_g_mem(dcurr, ncol, nlay, sw_ngpt); dcurr += real_size1;
view_t<RealT***> lw_subcloud_tau_mem(dcurr, ncol, nlay, lw_ngpt); dcurr += real_size2;
view_t<RealT***> sw_subcloud_tau_mem(dcurr, ncol, nlay, sw_ngpt); dcurr += real_size3;
view_t<RealT***> sw_subcloud_ssa_mem(dcurr, ncol, nlay, sw_ngpt); dcurr += real_size3;
view_t<RealT***> sw_subcloud_g_mem (dcurr, ncol, nlay, sw_ngpt); dcurr += real_size3;
view_t<RealT***> lw_subcloud_tau_mem(dcurr, ncol, nlay, lw_ngpt); dcurr += real_size4;
assert(dcurr - data.data() == total_real_size);

// Setup pointers to RRTMGP SW fluxes
fluxes_t fluxes_sw;
Expand Down Expand Up @@ -532,6 +538,7 @@ static void rrtmgp_main(
// subcolumn (cloud state) to each gpoint.
auto nswgpts = k_dist_sw_k.get_ngpt();
auto clouds_sw_gpt = get_subsampled_clouds(ncol, nlay, nswbands, nswgpts, clouds_sw, k_dist_sw_k, cldfrac, p_lay, sw_subcloud_band2gpt_mem, sw_subcloud_gpt2band_mem, sw_subcloud_tau_mem, sw_subcloud_ssa_mem, sw_subcloud_g_mem);

// Longwave
auto nlwgpts = k_dist_lw_k.get_ngpt();
auto clouds_lw_gpt = get_subsampled_clouds(ncol, nlay, nlwbands, nlwgpts, clouds_lw, k_dist_lw_k, cldfrac, p_lay, lw_subcloud_band2gpt_mem, lw_subcloud_gpt2band_mem, lw_subcloud_tau_mem);
Expand Down Expand Up @@ -729,21 +736,23 @@ static void rrtmgp_sw(
auto sw_noaero_tau_mem = view_t<RealT***>(dcurr, nday, nlay, ngpt); dcurr += size11;
auto sw_noaero_ssa_mem = view_t<RealT***>(dcurr, nday, nlay, ngpt); dcurr += size11;
auto sw_noaero_g_mem = view_t<RealT***>(dcurr, nday, nlay, ngpt); dcurr += size11;
assert(dcurr - data.data() == total_size);

const int int_size1 = 2*nbnd;
const int int_size2 = nbnd;
const int int_size3 = ngpt;
const int total_int_size = 3 * (int_size1 + int_size3) + (int_size1 + int_size2);
auto int_data = pool_t::template alloc_and_init<int>(total_int_size); int *dcurr_int = int_data.data();

auto sw_aero_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
auto sw_aero_gpt2band_mem = view_t<int*>(dcurr_int, nbnd); dcurr_int += int_size2;
auto sw_cloud_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
auto sw_cloud_gpt2band_mem = view_t<int*>(dcurr_int, ngpt); dcurr_int += int_size3;
auto sw_aero_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
auto sw_aero_gpt2band_mem = view_t<int*> (dcurr_int, nbnd); dcurr_int += int_size2;
auto sw_cloud_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
auto sw_cloud_gpt2band_mem = view_t<int*> (dcurr_int, ngpt); dcurr_int += int_size3;
auto sw_optics_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
auto sw_optics_gpt2band_mem = view_t<int*>(dcurr_int, ngpt); dcurr_int += int_size3;
auto sw_optics_gpt2band_mem = view_t<int*> (dcurr_int, ngpt); dcurr_int += int_size3;
auto sw_noaero_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
auto sw_noaero_gpt2band_mem = view_t<int*>(dcurr_int, ngpt); dcurr_int += int_size3;
auto sw_noaero_gpt2band_mem = view_t<int*> (dcurr_int, ngpt); dcurr_int += int_size3;
assert(dcurr_int - int_data.data() == total_int_size);

// Subset mu0
TIMED_KERNEL(Kokkos::parallel_for(nday, KOKKOS_LAMBDA(int iday) {
Expand Down Expand Up @@ -942,31 +951,33 @@ static void rrtmgp_lw(
const int total_size = size1 + size2 + size3*2 + size4 + size5 + size6 + size7*5 + size8;
auto data = pool_t::template alloc_and_init<RealT>(total_size); RealT *dcurr = data.data();

view_t<RealT*> t_sfc (dcurr, ncol); dcurr += size1;
view_t<RealT**> emis_sfc (dcurr, nbnd,ncol); dcurr += size2;
view_t<RealT**> gauss_Ds (dcurr, max_gauss_pts,max_gauss_pts); dcurr += size3;
view_t<RealT**> gauss_wts (dcurr, max_gauss_pts,max_gauss_pts); dcurr += size3;
view_t<RealT**> t_lay_limited(dcurr, ncol, nlay); dcurr += size4;
view_t<RealT**> t_lev_limited(dcurr, ncol, nlay+1); dcurr += size5;
view_t<RealT***> col_gas (dcurr, ncol, nlay, k_dist.get_ngas()+1); dcurr += size6;
view_t<RealT***> lw_optics_tau_mem(dcurr, ncol, nlay, ngpt); dcurr += size7;
view_t<RealT***> lw_noaero_tau_mem(dcurr, ncol, nlay, ngpt); dcurr += size7;
view_t<RealT***> lay_source_mem(dcurr, ncol, nlay, ngpt); dcurr += size7;
view_t<RealT***> lev_source_inc_mem(dcurr, ncol, nlay, ngpt); dcurr += size7;
view_t<RealT***> lev_source_dec_mem(dcurr, ncol, nlay, ngpt); dcurr += size7;
view_t<RealT**> sfc_source_mem(dcurr, ncol, ngpt); dcurr += size8;
view_t<RealT*> t_sfc (dcurr, ncol); dcurr += size1;
view_t<RealT**> emis_sfc (dcurr, nbnd, ncol); dcurr += size2;
view_t<RealT**> gauss_Ds (dcurr, max_gauss_pts, max_gauss_pts); dcurr += size3;
view_t<RealT**> gauss_wts (dcurr, max_gauss_pts, max_gauss_pts); dcurr += size3;
view_t<RealT**> t_lay_limited (dcurr, ncol, nlay); dcurr += size4;
view_t<RealT**> t_lev_limited (dcurr, ncol, nlay+1); dcurr += size5;
view_t<RealT***> col_gas (dcurr, ncol, nlay, k_dist.get_ngas()+1); dcurr += size6;
view_t<RealT***> lw_optics_tau_mem (dcurr, ncol, nlay, ngpt); dcurr += size7;
view_t<RealT***> lw_noaero_tau_mem (dcurr, ncol, nlay, ngpt); dcurr += size7;
view_t<RealT***> lay_source_mem (dcurr, ncol, nlay, ngpt); dcurr += size7;
view_t<RealT***> lev_source_inc_mem (dcurr, ncol, nlay, ngpt); dcurr += size7;
view_t<RealT***> lev_source_dec_mem (dcurr, ncol, nlay, ngpt); dcurr += size7;
view_t<RealT**> sfc_source_mem (dcurr, ncol, ngpt); dcurr += size8;
assert(dcurr - data.data() == total_size);

const int int_size1 = 2*nbnd;
const int int_size2 = ngpt;
const int total_int_size = 3 * (int_size1 + int_size2);
auto int_data = pool_t::template alloc_and_init<int>(total_int_size); int *dcurr_int = int_data.data();

auto lw_optics_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd);
auto lw_optics_gpt2band_mem = view_t<int*>(dcurr_int, ngpt);
auto lw_noaero_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd);
auto lw_noaero_gpt2band_mem = view_t<int*>(dcurr_int, ngpt);
auto lw_source_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd);
auto lw_source_gpt2band_mem = view_t<int*>(dcurr_int, ngpt);
auto lw_optics_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
auto lw_optics_gpt2band_mem = view_t<int*> (dcurr_int, ngpt); dcurr_int += int_size2;
auto lw_noaero_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
auto lw_noaero_gpt2band_mem = view_t<int*> (dcurr_int, ngpt); dcurr_int += int_size2;
auto lw_source_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
auto lw_source_gpt2band_mem = view_t<int*> (dcurr_int, ngpt); dcurr_int += int_size2;
assert(dcurr_int - int_data.data() == total_int_size);

// Associate local pointers for fluxes
auto &flux_up = fluxes.flux_up;
Expand Down Expand Up @@ -1002,6 +1013,7 @@ static void rrtmgp_lw(
// Allocate space for optical properties
optical_props1_t optics;
optics.alloc_1scl_no_alloc(ncol, nlay, k_dist, lw_optics_band2gpt_mem, lw_optics_gpt2band_mem, lw_optics_tau_mem);

optical_props1_t optics_no_aerosols;
if (extra_clnsky_diag) {
// Allocate space for optical properties (no aerosols)
Expand Down Expand Up @@ -1490,6 +1502,7 @@ static optical_props2_t get_subsampled_clouds(
cldfrac_rad(icol,ilay) = cld(icol,ilay);
}
}));

// Get subcolumn cloud mask; note that get_subcolumn_mask exposes overlap assumption as an option,
// but the only currently supported options are 0 (trivial all-or-nothing cloud) or 1 (max-rand),
// so overlap has not been exposed as an option beyond this subcolumn. In the future, we should
Expand All @@ -1503,6 +1516,7 @@ static optical_props2_t get_subsampled_clouds(
seeds(icol) = 1e9 * (p_lay(icol,nlay-1) - int(p_lay(icol,nlay-1)));
}));
get_subcolumn_mask(ncol, nlay, ngpt, cldfrac_rad, overlap, seeds, cldmask);

// Assign optical properties to subcolumns (note this implements MCICA)
auto gpoint_bands = kdist.get_gpoint_bands();
TIMED_KERNEL(Kokkos::parallel_for(MDRP::template get<3>({ncol,nlay,ngpt}), KOKKOS_LAMBDA(int icol, int ilay, int igpt) {
Expand Down

0 comments on commit 3751726

Please sign in to comment.