We have observed buggy behavior involving bitcasts and dynamic-update-slices (DUS). When we enable activation checkpointing (e.g., saving the outputs of projection layers with the SAVE_OUT_PROJ flag in PAXML), we see multiple extra updates and copies.
For example, we want to checkpoint an activation of shape [2,2048,48,128]. However, in the HLO below, the copies have shape [15,1,2,2048,48,128], where 15 is the number of microbatches we use with pipeline parallelism.
fusion.549 = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, ..., kind=kLoop, calls=fused_computation.549, metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/dynamic_update_slice" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
get-tuple-element.5874 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=0
copy.583 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5874)
get-tuple-element.5866 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=1
copy.575 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5866)
get-tuple-element.5868 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=2
copy.577 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5868)
get-tuple-element.5870 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=3
copy.579 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5870)
get-tuple-element.5872 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=4
copy.581 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5872)
...
fused_computation.549 {
param_1.8511 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} parameter(1)
bitcast.52601 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_1.8511)
param_0.6313 = bf16[2,48,128,2048]{3,2,1,0} parameter(0)
bitcast.52600 = bf16[1,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_0.6313)
param_2.5901 = s32[] parameter(2)
constant_7564 = s32[] constant(0)
compare.3477 = pred[] compare(param_2.5901, constant_7564), direction=LT, metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/pipeline._scan_fn/pipeline._get_iteration_inputs/jit(remainder)/rem" source_file="/pax/praxis/praxis/layers/pipeline.py" source_line=422}
constant_11524 = s32[] constant(15)
add.6580 = s32[] add(param_2.5901, constant_11524), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/add" source_file="/pax/praxis/praxis/base_layer.py" source_line=695}
select.5360 = s32[] select(compare.3477, add.6580, param_2.5901), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/select_n" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
dynamic-update-slice.325 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} dynamic-update-slice(bitcast.52601, bitcast.52600, select.5360, constant_7564, constant_7564, /*index=5*/constant_7564, constant_7564, constant_7564), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/dynamic_update_slice" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
bitcast.52599 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} bitcast(dynamic-update-slice.325), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/dynamic_update_slice" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
param_4.7770 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} parameter(4)
bitcast.52617.clone.1 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_4.7770)
param_3.8428 = bf16[2,48,128,2048]{3,2,1,0} parameter(3)
bitcast.52616.clone.1 = bf16[1,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_3.8428)
dynamic-update-slice.333.clone.1 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} dynamic-update-slice(bitcast.52617.clone.1, bitcast.52616.clone.1, select.5360, constant_7564, constant_7564, /*index=5*/constant_7564, constant_7564, constant_7564), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/dynamic_update_slice" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
...
ROOT tuple.356 = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}) tuple(bitcast.52599, bitcast.52615.clone.1, bitcast.52611.clone.1, bitcast.52607.clone.1, bitcast.52603.clone.1)
}
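The following plain-Python toy model makes the fused computation above easier to read. It uses lists instead of tensors, and the names `normalize_index` and `dynamic_update_row` are hypothetical, not XLA APIs. The model shows what each output of fused_computation.549 computes. compare.3477, add.6580 and select.5360 wrap a possibly negative microbatch index into [0, 15), and the dynamic-update-slice then writes one row into the stacked buffer. The clamping step follows XLA's DynamicUpdateSlice semantics, which clamp start indices so the update stays in bounds.

```python
# Toy model (plain Python lists, not XLA) of one output of fused_computation.549:
# a dynamic-update-slice that writes one microbatch row into a stacked buffer.

NUM_MICROBATCHES = 15  # constant_11524 in the HLO


def normalize_index(i: int, n: int = NUM_MICROBATCHES) -> int:
    """Mirror of compare.3477 / add.6580 / select.5360: wrap a negative index."""
    return i + n if i < 0 else i


def dynamic_update_row(buffer: list, row, i: int) -> list:
    """Return a copy of `buffer` with row `i` replaced (a non-in-place DUS).

    XLA's DynamicUpdateSlice clamps start indices so the update fits;
    for a single-row update that means clamping to [0, len(buffer) - 1].
    """
    start = normalize_index(i, len(buffer))
    start = max(0, min(start, len(buffer) - 1))
    out = list(buffer)  # full copy of the buffer, even though one row changes
    out[start] = row
    return out


buf = [0] * NUM_MICROBATCHES
buf = dynamic_update_row(buf, 7, 3)
buf = dynamic_update_row(buf, 9, -1)  # -1 wraps to the last microbatch
print(buf)
```

In a purely functional model like this, every update copies the whole buffer. In-place buffer assignment is what should let the DUS write only the one row that changes.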
It seems that there is a large buffer of shape [15,1,2,2048,48,128] holding the activations for all microbatches. In each microbatch iteration, we update one row of this buffer (of shape [2,2048,48,128]). However, instead of updating that row in place, XLA appears to produce a whole new buffer, perform the update, and then copy the entire buffer back. This shows up in our profiles: the time spent on device-to-device (D2D) copies (copy.575 to copy.583) is much larger than expected for the amount of data that should be copied. Currently, activation checkpointing accounts for 5% to 8% of the overall run time for a GPT-3-style model.
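Here is a rough estimate of the mismatch described above. It assumes that each of the five copies visible in the excerpt (copy.575 to copy.583) moves the entire stacked buffer once per pipeline-loop iteration, while only one [2,2048,48,128] slice per buffer actually changes. It counts bytes written only. A D2D copy also reads the same amount.

```python
# Back-of-the-envelope estimate, assuming each copy moves the whole stacked
# buffer once per pipeline-loop iteration (bytes written only).
from math import prod

BF16_BYTES = 2
activation_shape = (2, 2048, 48, 128)      # one microbatch's checkpointed activation
buffer_shape = (15, 1) + activation_shape  # stacked over 15 microbatches
num_outputs = 5                            # copies visible in the HLO excerpt

slice_bytes = prod(activation_shape) * BF16_BYTES
buffer_bytes = prod(buffer_shape) * BF16_BYTES

MiB = 1024 ** 2
print(f"slice:  {slice_bytes / MiB:.0f} MiB")
print(f"buffer: {buffer_bytes / MiB:.0f} MiB")
print(f"needed per iteration: {num_outputs * slice_bytes / MiB:.0f} MiB")
print(f"copied per iteration: {num_outputs * buffer_bytes / MiB:.0f} MiB")
print(f"overhead factor: {buffer_bytes // slice_bytes}x")
```

Under these assumptions, each iteration writes about 3600 MiB instead of about 240 MiB, a 15x overhead equal to the number of microbatches.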
Our current understanding: the copies occur because bitcast is treated as computing a new value (like convert or sqrt). In that case, each loop iteration must produce a new tensor, so a copy of each DUS result must be made. This should be fixable by having the dataflow analysis treat bitcast as an aliasing operation rather than as one that computes a new value. I think there is an option in the dataflow analysis that configures how bitcasts are treated. In XLA TPU, this option is set to true, so bitcasts are treated simply as aliasing operations.
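The hypothesis above can be illustrated with a toy buffer-sharing model. This is not XLA's HloDataflowAnalysis. The `Instr` class and the `buffer_of`/`needs_copy` helpers are hypothetical, and the model collapses the fusion boundary by treating param_1.8511 as if it were the loop-carried while-body value. It follows the chain param_1.8511 → bitcast.52601 → dynamic-update-slice.325 → bitcast.52599 from the HLO and asks whether the result can reuse the parameter's buffer.

```python
# Toy model (NOT XLA's HloDataflowAnalysis) of why a loop-carried buffer
# needs a copy when bitcast is treated as defining a new value.
from dataclasses import dataclass, field


@dataclass
class Instr:
    name: str
    opcode: str
    operands: list = field(default_factory=list)


def buffer_of(instr: Instr, bitcast_aliases: bool) -> str:
    """Return the name of the instruction whose buffer holds `instr`'s result."""
    if instr.opcode == "bitcast" and bitcast_aliases:
        return buffer_of(instr.operands[0], bitcast_aliases)
    if instr.opcode == "dynamic-update-slice":
        # Assume DUS runs in place, reusing the buffer of the operand it updates.
        return buffer_of(instr.operands[0], bitcast_aliases)
    return instr.name  # every other op (including a value-defining bitcast) gets a fresh buffer


def needs_copy(param: Instr, root: Instr, bitcast_aliases: bool) -> bool:
    """A loop-body output that does not reuse its parameter's buffer needs a copy."""
    return buffer_of(root, bitcast_aliases) != buffer_of(param, bitcast_aliases)


# The chain from fused_computation.549, first output only.
param = Instr("param_1.8511", "parameter")
src = Instr("bitcast.52601", "bitcast", [param])
upd = Instr("bitcast.52600", "bitcast", [Instr("param_0.6313", "parameter")])
dus = Instr("dynamic-update-slice.325", "dynamic-update-slice", [src, upd])
root = Instr("bitcast.52599", "bitcast", [dus])

print(needs_copy(param, root, bitcast_aliases=False))
print(needs_copy(param, root, bitcast_aliases=True))
```

With bitcasts treated as aliases, the whole chain resolves to the parameter's buffer, so the update can stay in place. With bitcasts defining new values, the result lands in a fresh buffer and a copy is needed to feed the next iteration. The model ignores other conditions that real copy insertion checks, such as whether the old value is still live.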