JAX partitioning error when attempting to run with sequence parallelism factor not a power of 2 #9

@exists-forall

Description

I'm trying to run Ring Attention on a machine with 6 A100 GPUs, and I'm finding that when I try to set the sequence parallelism dimension to anything other than a power of 2, the process crashes with a JAX partitioning error.

I'd be grateful for any insight into whether or not I'm doing something wrong in the way I'm invoking the training script, and for any advice on how to work around this issue.

Steps to Reproduce

Consider the following script for invoking llamabpt.train:

#########################################################
### Configuration 1 (runs successfully)               ###
#########################################################
export CUDA_VISIBLE_DEVICES="0,1,2,3"
SEQ_PAR_DIM=4
MAX_SEQ_LEN=131072
#########################################################

#########################################################
### Configuration 2 (CRASHES with partitioning error) ###
#########################################################
# export CUDA_VISIBLE_DEVICES="0,1,2"
# SEQ_PAR_DIM=3
# MAX_SEQ_LEN=98304
#########################################################

python3 -m llamabpt.train \
  --mesh_dim="1,1,1,${SEQ_PAR_DIM}" \
  --dtype=bf16 \
  --load_llama_config=1b \
  --update_llama_config="{'max_sequence_length': ${MAX_SEQ_LEN}, 'scan_attention': True, 'scan_query_chunk_size': 2048, 'scan_key_chunk_size': 4096, 'remat_attention': 'nothing_saveable', 'scan_mlp': True, 'scan_mlp_chunk_size': 2048, 'remat_mlp': 'nothing_saveable', 'remat_block': 'nothing_saveable', 'scan_layers': True, 'attention_type': 'ring_blockwise', 'param_scan_axis': 0, 'mesh_dim': '1,1,1,${SEQ_PAR_DIM}'}" \
  --total_steps=2 \
  --log_freq=1 \
  --save_model_freq=0 \
  --save_milestone_freq=1000 \
  --tokenizer.vocab_file="${TOKENIZER_PATH}" \
  --optimizer.type=adamw \
  --optimizer.adamw_optimizer.weight_decay=0.1 \
  --optimizer.adamw_optimizer.lr=1.5e-4 \
  --optimizer.adamw_optimizer.end_lr=1.5e-5 \
  --optimizer.adamw_optimizer.lr_warmup_steps=1 \
  --optimizer.adamw_optimizer.lr_decay_steps=10 \
  --train_dataset.type=json \
  --train_dataset.text_processor.fields=text \
  --train_dataset.json_dataset.path="${TRAIN_DATA_PATH}" \
  --train_dataset.json_dataset.seq_length=${MAX_SEQ_LEN} \
  --train_dataset.json_dataset.batch_size=1 \
  --train_dataset.json_dataset.tokenizer_processes=16
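Both configurations use the same number of tokens per GPU (131072 / 4 = 98304 / 3 = 32768), so the per-run variables can be derived from SEQ_PAR_DIM alone instead of commenting and uncommenting blocks. This is a minimal sketch: `gpu_list`, `configure`, and `TOKENS_PER_SHARD` are hypothetical names, and it assumes one GPU per sequence-parallel shard.

```shell
#!/usr/bin/env bash
# Hypothetical helper: derive CUDA_VISIBLE_DEVICES and MAX_SEQ_LEN from SEQ_PAR_DIM.
# Assumes one GPU per sequence-parallel shard and 32768 tokens per shard,
# which matches both configurations in the script above.

TOKENS_PER_SHARD=32768

# Print a comma-separated list of GPU indices 0..n-1, e.g. "0,1,2" for n=3.
gpu_list() {
  local n=$1 i list="0"
  for ((i = 1; i < n; i++)); do
    list+=",$i"
  done
  printf '%s\n' "$list"
}

# Set the three variables the launch command reads, for a given SP size.
configure() {
  local sp=$1
  SEQ_PAR_DIM=$sp
  MAX_SEQ_LEN=$((sp * TOKENS_PER_SHARD))
  CUDA_VISIBLE_DEVICES=$(gpu_list "$sp")
  export CUDA_VISIBLE_DEVICES
}

# Default to Configuration 1 (SP=4) when no argument is given.
configure "${1:-4}"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES SEQ_PAR_DIM=$SEQ_PAR_DIM MAX_SEQ_LEN=$MAX_SEQ_LEN"
```

Passing `3` as the first argument would reproduce Configuration 2 (GPUs `0,1,2`, `MAX_SEQ_LEN=98304`) before the `python3 -m llamabpt.train` invocation.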

For Configuration 1, where the sequence parallelism dimension is 4, the training script runs as expected without errors.

However, when I uncomment Configuration 2, where the sequence parallelism dimension is 3, the training script crashes with the following error:

ValueError: One of pjit outputs with pytree key path .params['params']['lm_head']['kernel'] was given the sharding of NamedSharding(mesh={'dp': 1, 'fsdp': 1, 'tp': 1, 'sp': 3}, spec=PartitionSpec(('fsdp', 'sp'), 'tp')), which implies that the global size of its dimension 0 should be divisible by 3, but it is equal to 2048 (full shape: (2048, 32000))

The error occurs during the first call to sharded_init_fn.
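The constraint in the error concerns a parameter, not the sequence: the `lm_head` kernel's dimension 0 (size 2048) is sharded over the mesh axes `('fsdp', 'sp')`, so it is split across fsdp × sp = 1 × 3 = 3 devices and must be a multiple of 3. The sketch below reproduces just that arithmetic; `can_shard` is a hypothetical helper, and the shapes are taken from the error message above.

```shell
#!/usr/bin/env bash
# Sketch of the divisibility rule stated in the error message: a dimension
# sharded over several mesh axes is split across the product of their sizes,
# so its global size must be a multiple of that product.

# can_shard DIM_SIZE AXIS_SIZE... -> exit status 0 if DIM_SIZE is divisible
# by the product of the given mesh-axis sizes, 1 otherwise.
can_shard() {
  local dim=$1
  shift
  local prod=1 size
  for size in "$@"; do
    prod=$((prod * size))
  done
  (( dim % prod == 0 ))
}

# lm_head kernel from the error: shape (2048, 32000), dim 0 sharded over ('fsdp', 'sp').
HIDDEN=2048
FSDP=1
for SP in 4 3; do
  if can_shard "$HIDDEN" "$FSDP" "$SP"; then
    echo "sp=$SP: $HIDDEN is divisible by $((FSDP * SP))"
  else
    echo "sp=$SP: $HIDDEN is NOT divisible by $((FSDP * SP)) (remainder $((HIDDEN % (FSDP * SP))))"
  fi
done
```

For SP=3 this reports a remainder of 2, matching the failure, while the sequence length 98304 is itself divisible by 3.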

I would expect Configuration 2 to run successfully, because the total sequence length (98304) is a multiple of the sequence parallelism dimension (3).

Trying other sequence parallelism dimensions, I find that:

  • Setting SEQ_PAR_DIM to either 2 or 4 runs successfully.
  • Setting SEQ_PAR_DIM to either 3 or 6 crashes with a partitioning error.
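Since 2048 = 2^11, its only divisors are powers of 2, which would account for this pattern if the `lm_head` dimension is the limiting factor. A quick sweep over the SP sizes a 6-GPU machine allows, assuming fsdp=1 as in `--mesh_dim="1,1,1,${SEQ_PAR_DIM}"`:

```shell
#!/usr/bin/env bash
# Sweep candidate sequence-parallel sizes against the lm_head dimension 0
# (2048, from the error message), assuming the fsdp axis has size 1.
HIDDEN=2048
ok=()
bad=()
for sp in 1 2 3 4 5 6; do
  if (( HIDDEN % sp == 0 )); then
    ok+=("$sp")
  else
    bad+=("$sp")
  fi
done
echo "divides $HIDDEN: ${ok[*]}"
echo "does not divide $HIDDEN: ${bad[*]}"
```

This prints `1 2 4` as dividing 2048 and `3 5 6` as not, consistent with the successes and failures listed above.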
