I'm trying to run Ring Attention on a machine with 6 A100 GPUs, and I'm finding that when I try to set the sequence parallelism dimension to anything other than a power of 2, the process crashes with a JAX partitioning error.
I'd be grateful for any insight into whether or not I'm doing something wrong in the way I'm invoking the training script, and for any advice on how to work around this issue.
Steps to Reproduce
Consider the following script for invoking llamabpt.train:
#########################################################
### Configuration 1 (runs successfully) ###
#########################################################
export CUDA_VISIBLE_DEVICES="0,1,2,3"
SEQ_PAR_DIM=4
MAX_SEQ_LEN=131072
#########################################################
#########################################################
### Configuration 2 (CRASHES with partitioning error) ###
#########################################################
# export CUDA_VISIBLE_DEVICES="0,1,2"
# SEQ_PAR_DIM=3
# MAX_SEQ_LEN=98304
#########################################################
python3 -m llamabpt.train \
--mesh_dim="1,1,1,${SEQ_PAR_DIM}" \
--dtype=bf16 \
--load_llama_config=1b \
--update_llama_config="{'max_sequence_length': ${MAX_SEQ_LEN}, 'scan_attention': True, 'scan_query_chunk_size': 2048, 'scan_key_chunk_size': 4096, 'remat_attention': 'nothing_saveable', 'scan_mlp': True, 'scan_mlp_chunk_size': 2048, 'remat_mlp': 'nothing_saveable', 'remat_block': 'nothing_saveable', 'scan_layers': True, 'attention_type': 'ring_blockwise', 'param_scan_axis': 0, 'mesh_dim': '1,1,1,${SEQ_PAR_DIM}'}" \
--total_steps=2 \
--log_freq=1 \
--save_model_freq=0 \
--save_milestone_freq=1000 \
--tokenizer.vocab_file="${TOKENIZER_PATH}" \
--optimizer.type=adamw \
--optimizer.adamw_optimizer.weight_decay=0.1 \
--optimizer.adamw_optimizer.lr=1.5e-4 \
--optimizer.adamw_optimizer.end_lr=1.5e-5 \
--optimizer.adamw_optimizer.lr_warmup_steps=1 \
--optimizer.adamw_optimizer.lr_decay_steps=10 \
--train_dataset.type=json \
--train_dataset.text_processor.fields=text \
--train_dataset.json_dataset.path="${TRAIN_DATA_PATH}" \
--train_dataset.json_dataset.seq_length=${MAX_SEQ_LEN} \
--train_dataset.json_dataset.batch_size=1 \
--train_dataset.json_dataset.tokenizer_processes=16
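Both configurations keep the per-device sequence length fixed at 32768 tokens (131072 / 4 = 98304 / 3 = 32768), so the three variables can be derived from SEQ_PAR_DIM alone. The sketch below assumes that rule should also hold for other SEQ_PAR_DIM values. The helper name make_config is hypothetical and is not part of llamabpt.

```sh
#!/bin/sh
# Hypothetical helper, not part of llamabpt: derive the three variables from
# SEQ_PAR_DIM, keeping 32768 tokens per device as both configurations above do.
TOKENS_PER_DEVICE=32768

make_config() {
  sp="$1"
  # Build the comma-separated GPU list 0,1,...,sp-1.
  devices=0
  i=1
  while [ "$i" -lt "$sp" ]; do
    devices="${devices},${i}"
    i=$((i + 1))
  done
  echo "CUDA_VISIBLE_DEVICES=${devices} SEQ_PAR_DIM=${sp} MAX_SEQ_LEN=$((TOKENS_PER_DEVICE * sp))"
}

make_config 4   # CUDA_VISIBLE_DEVICES=0,1,2,3 SEQ_PAR_DIM=4 MAX_SEQ_LEN=131072
make_config 3   # CUDA_VISIBLE_DEVICES=0,1,2 SEQ_PAR_DIM=3 MAX_SEQ_LEN=98304
```

Running `eval "export $(make_config 3)"` before the `python3 -m llamabpt.train` call would set the Configuration 2 values without editing the commented blocks.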
For Configuration 1, where the sequence parallelism dimension is 4, the training script runs as expected without errors.
However, when I uncomment Configuration 2, where the sequence parallelism dimension is 3, the training script crashes with the following error:
ValueError: One of pjit outputs with pytree key path .params['params']['lm_head']['kernel'] was given the sharding of NamedSharding(mesh={'dp': 1, 'fsdp': 1, 'tp': 1, 'sp': 3}, spec=PartitionSpec(('fsdp', 'sp'), 'tp')), which implies that the global size of its dimension 0 should be divisible by 3, but it is equal to 2048 (full shape: (2048, 32000))
The error occurs during the first call to sharded_init_fn.
I would expect Configuration 2 to run successfully, because the total sequence length (98304) is a multiple of the sequence parallelism dimension (3).
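That expectation checks only the sequence axis. The error message states a second constraint. Dimension 0 of the lm_head kernel is sharded with PartitionSpec(('fsdp', 'sp'), 'tp'), so its size must be divisible by the product of the fsdp and sp mesh sizes. The sketch below runs both checks for Configuration 2. All sizes are copied from the configuration and from the error's mesh and full shape (2048, 32000). The helper name check_divisible is hypothetical.

```sh
#!/bin/sh
# Sketch of the divisibility rule quoted in the error message: a dimension
# sharded over several mesh axes must divide evenly by the product of their sizes.
# Values below are copied from Configuration 2 and from the error message.
FSDP=1
SP=3
MAX_SEQ_LEN=98304
LM_HEAD_DIM0=2048   # dimension 0 of params['params']['lm_head']['kernel']

# check_divisible NAME SIZE SHARDS: report whether SIZE splits evenly into SHARDS.
check_divisible() {
  if [ $(( $2 % $3 )) -eq 0 ]; then
    echo "$1: $2 / $3 = $(( $2 / $3 )) (ok)"
  else
    echo "$1: $2 % $3 = $(( $2 % $3 )) (not divisible)"
  fi
}

# Sequence axis: the check the expectation above relies on.
check_divisible "sequence length over sp" "$MAX_SEQ_LEN" "$SP"
# lm_head kernel, dim 0, spec ('fsdp', 'sp'): must divide by fsdp * sp.
check_divisible "lm_head dim 0 over (fsdp, sp)" "$LM_HEAD_DIM0" $(( FSDP * SP ))
```

The first check passes (98304 / 3 = 32768). The second fails because 2048 leaves a remainder of 2 when divided by 3, which is exactly the condition the ValueError reports.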
Generalizing to more sequence parallelism dimensions, I find that:
- Setting SEQ_PAR_DIM to either 2 or 4 runs successfully.
- Setting SEQ_PAR_DIM to either 3 or 6 crashes with a partitioning error.
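This pattern is consistent with the lm_head constraint from the error message. Because 2048 = 2^11, the only values of fsdp * sp that divide it are powers of two. The sketch below sweeps SEQ_PAR_DIM from 1 to 6 with fsdp = 1, as in --mesh_dim="1,1,1,N". It assumes the lm_head kernel keeps the sharding shown in the error and checks only that one parameter, so other parameters could impose further constraints. The function name lm_head_divisible is hypothetical.

```sh
#!/bin/sh
# Sweep SEQ_PAR_DIM with fsdp = 1 and test whether dimension 0 of the lm_head
# kernel (2048, from the error message) splits evenly across fsdp * sp shards.
LM_HEAD_DIM0=2048
FSDP=1

# lm_head_divisible SP: exit status 0 if 2048 is divisible by FSDP * SP.
lm_head_divisible() {
  [ $((LM_HEAD_DIM0 % (FSDP * $1))) -eq 0 ]
}

sp=1
while [ "$sp" -le 6 ]; do
  if lm_head_divisible "$sp"; then
    echo "SEQ_PAR_DIM=$sp: ok"
  else
    echo "SEQ_PAR_DIM=$sp: not divisible"
  fi
  sp=$((sp + 1))
done
```

The sweep reports 1, 2 and 4 as divisible and 3, 5 and 6 as not, which lines up with the observed results for 2, 3, 4 and 6.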