I think we can change the tiled-communication part. I moved the tile start and end calculations to after the sync point, and also converted those variables to locally defined variables. Test these changes separately.
// Tiled update of the first row (iy_start) using halo data from the top neighbor.
// For each communication tile: wait for the neighbor's flag, update this tile's
// columns of row iy_start, then publish a per-tile completion flag to the neighbor.
// NOTE(review): comm_tile_idx is assigned, not declared, here, so it must be
// declared outside this snippet.
for (comm_tile_idx = 0; comm_tile_idx < num_comm_tiles; comm_tile_idx++) {
// Flag slot polled for this tile in the current iteration's flag set.
int cur_iter_comm_tile_flag_idx = comm_tile_idx + cur_iter_mod * num_flags;
// Only thread rank 0 of `cta` polls; the remaining threads wait at the sync below.
if (cta.thread_rank() == 0) {
// Busy-wait until the flag for this tile equals the current iteration number.
// NOTE(review): this relies on the flag load being re-issued each pass (e.g. a
// volatile-qualified pointer); the declaration is not visible here, so confirm.
while (local_is_top_neighbor_done_writing_to_me[cur_iter_comm_tile_flag_idx] !=
iter) {
// Intentionally empty: spin.
}
}
// Hold every thread in `cta` until rank 0 has observed the flag.
cg::sync(cta);
// Column range [comm_tile_start, comm_tile_end) of this tile. The first tile
// starts at column 1 and the last tile ends at nx - 1, so columns 0 and nx - 1
// are never written by this loop.
int comm_tile_start = (comm_tile_idx == 0) ? 1 : comm_tile_idx * comm_tile_size;
int comm_tile_end = (comm_tile_idx == (num_comm_tiles - 1))
? nx - 1
: (comm_tile_idx + 1) * comm_tile_size;
// One column per thread: linearized (y, x) thread index offset by the tile start.
// NOTE(review): there is no stride loop, so columns beyond the block's thread
// count within a tile would be skipped; presumably the tile width never exceeds
// the number of threads per block, so confirm against the launch configuration.
int col = threadIdx.y * blockDim.x + threadIdx.x + comm_tile_start;
if (col < comm_tile_end) {
// Average of four values: right and left neighbors in row iy_start, the value
// in row iy_start + 1, and the entry read from
// remote_my_halo_buffer_on_top_neighbor in the cur_iter_mod slot.
// NOTE(review): 0.25 is a double literal; if `real` is float the arithmetic is
// promoted to double before narrowing, so confirm this is intended.
const real first_row_val =
0.25 * (a[iy_start * nx + col + 1] + a[iy_start * nx + col - 1] +
a[(iy_start + 1) * nx + col] +
remote_my_halo_buffer_on_top_neighbor[nx * cur_iter_mod + col]);
// Store the result in a_new and also in the next_iter_mod slot of
// local_halo_buffer_for_top_neighbor.
a_new[iy_start * nx + col] = first_row_val;
local_halo_buffer_for_top_neighbor[nx * next_iter_mod + col] = first_row_val;
}
// All threads in `cta` finish issuing their stores for this tile before the flag
// is set below.
// NOTE(review): no explicit memory fence is visible between these stores and the
// flag store; whether they are guaranteed visible to the neighbor before the flag
// depends on declarations and fences outside this snippet, so confirm.
cg::sync(cta);
if (cta.thread_rank() == 0) {
// The signalled slot is offset by num_comm_tiles and uses next_iter_mod, unlike
// the polled slot above.
// NOTE(review): presumably the second half of each flag set belongs to the
// opposite direction; verify against the flag layout.
int next_iter_comm_tile_flag_idx =
(num_comm_tiles + comm_tile_idx) + next_iter_mod * num_flags;
// Publish iter + 1 as this tile's completion flag.
remote_am_done_writing_to_top_neighbor[next_iter_comm_tile_flag_idx] = iter + 1;
}
}