Fix issues for saving checkpointing steps #9891

leisuzz wants to merge into huggingface:main
Conversation
@sayakpaul Please take a look at this PR, thanks for your help!
sayakpaul left a comment
Thanks for your PR. Can you please modify a single file first so we can discuss the changes?
When I ran the DreamBooth Flux script without LoRA and saved a checkpoint, it got stuck for a while and then broke. So I think all the scripts need these modifications. I can do only the Flux ones if you want.
Yeah, let's change a single file first and then we can discuss the changes.
Sure
@sayakpaul I've applied the modifications only to the Flux scripts.
```python
if global_step % args.checkpointing_steps == 0:
    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
```
Well, there is a better way to handle it:

```python
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
```
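The intent of this guard can be sketched as a small predicate (a hypothetical helper for illustration, not part of the script): `accelerator.save_state` is a collective call under DeepSpeed, so every rank must enter the saving branch, otherwise the non-main ranks block and the run hangs.

```python
def should_enter_save_branch(is_main_process: bool, distributed_type: str) -> bool:
    """Mirror of `accelerator.is_main_process or accelerator.distributed_type
    == DistributedType.DEEPSPEED` (strings stand in for the enum here).
    Under DeepSpeed, saving state is collective and every rank must
    participate; on other backends only the main process should save."""
    return is_main_process or distributed_type == "DEEPSPEED"

# A non-main rank must still enter the branch under DeepSpeed...
assert should_enter_save_branch(False, "DEEPSPEED")
# ...but stays out of it on other backends, e.g. plain multi-GPU DDP.
assert not should_enter_save_branch(False, "MULTI_GPU")
```

With the original `is_main_process`-only check, the non-main DeepSpeed ranks never reach the collective save call, which matches the stuck-then-timeout behavior described above.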
sayakpaul left a comment

Thanks for your PR!
You can refer to the following scripts:
- https://github.com/a-r-r-o-w/cogvideox-factory/blob/main/training/cogvideox_image_to_video_lora.py
- https://github.com/a-r-r-o-w/cogvideox-factory/blob/main/training/cogvideox_text_to_video_sft.py
To see how we handle saving and loading from checkpoints when using DeepSpeed.
Search for `DistributedType.DEEPSPEED`.
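The save-hook pattern those scripts use can be sketched roughly as follows (a simplified sketch with a stand-in accelerator object and a generic `save_fn`; the real hooks serialize specific diffusers models and use the `DistributedType` enum rather than strings):

```python
from dataclasses import dataclass

@dataclass
class StubAccelerator:
    """Minimal stand-in for accelerate.Accelerator, for illustration only."""
    is_main_process: bool
    distributed_type: str

def save_model_hook(models, weights, accelerator, save_fn):
    # Under DeepSpeed every rank reaches this hook and must participate;
    # on other backends only the main process should write the checkpoint.
    if accelerator.is_main_process or accelerator.distributed_type == "DEEPSPEED":
        while len(models) > 0:
            model = models.pop()
            save_fn(model)
            # DeepSpeed may hand the hook an empty `weights` list,
            # so guard the pop to avoid an IndexError.
            if weights:
                weights.pop()

# Example: a non-main DeepSpeed rank with an empty weights list still saves.
saved = []
acc = StubAccelerator(is_main_process=False, distributed_type="DEEPSPEED")
save_model_hook(["unet"], [], acc, saved.append)
assert saved == ["unet"]
```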
@sayakpaul I've changed it based on the reference.
sayakpaul left a comment
Thanks! Left more comments. LMK if they're clear.
```python
            model.load_state_dict(load_model.state_dict())
        except Exception:
    elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):
```
We don't support fine-tuning the T5 model. So, this seems wrong. It should just be CLIPTextModelWithProjection, no?
```python
        try:
            load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_2")
            model(**load_model.config)
            model.load_state_dict(load_model.state_dict())
        except Exception:
            raise ValueError(f"Couldn't load the model of type: ({type(model)}).")
    else:
        raise ValueError(f"Unsupported model found: {type(model)=}")
```
```python
    try:
        load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder")
        model(**load_model.config)
```

```python
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
```
We also need to handle the case when we're actually doing DeepSpeed training. Similar to:
https://github.com/a-r-r-o-w/cogvideox-factory/blob/d63a826f37758eccf226710f94f6c3a4d4ee7a25/training/cogvideox_text_to_video_sft.py#L385
```python
while len(models) > 0:
    model = models.pop()
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
```
Same. We also need to handle the case when we are not doing DeepSpeed training. Reference:
https://github.com/a-r-r-o-w/cogvideox-factory/blob/d63a826f37758eccf226710f94f6c3a4d4ee7a25/training/cogvideox_text_to_video_lora.py#L396
```python
if weights:
    weights.pop()
```
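The guard matters because the save hook can receive an empty `weights` list under DeepSpeed (an observation from this thread, not documented behavior). A minimal repro of the failure the guard prevents:

```python
weights = []  # under DeepSpeed the save hook may receive no weights

# An unguarded pop raises on the empty list:
try:
    weights.pop()
except IndexError as err:
    print(err)  # pop from empty list

# The guarded version is a safe no-op:
if weights:
    weights.pop()
assert weights == []
```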
```python
def load_model_hook(models, input_dir):
```
Seems like we're not handling the loading case appropriately here. I've repeated this multiple times now, but please refer to the changes here to get an idea of what's required.
In summary, we're not dealing with the changes required to load the state dict into the models being trained when DeepSpeed is enabled.
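What is being asked for can be sketched like this (a simplified, hypothetical load hook with stub types and a generic `load_fn`; the referenced cogvideox-factory scripts do the equivalent by reloading the model with `from_pretrained` on the checkpoint directory): under DeepSpeed the popped-models path is skipped and the weights are restored into the training model directly.

```python
class StubModel:
    """Stand-in for an nn.Module; records the state dict it was given."""
    def __init__(self):
        self.state = None
    def load_state_dict(self, sd):
        self.state = sd

def load_model_hook(models, input_dir, accelerator, trained_model, load_fn):
    """Sketch of a DeepSpeed-aware load hook. `load_fn(input_dir)` stands in
    for reading a state dict from the checkpoint directory."""
    if accelerator.distributed_type != "DEEPSPEED":
        # Normal path: restore the models accelerate hands back to the hook.
        while len(models) > 0:
            model = models.pop()
            model.load_state_dict(load_fn(input_dir))
    else:
        # DeepSpeed path: load the weights into the training model directly,
        # since the `models` list is not used the same way here.
        trained_model.load_state_dict(load_fn(input_dir))

# Example: DeepSpeed case, the trained model is restored directly.
acc = type("A", (), {"distributed_type": "DEEPSPEED"})()
target = StubModel()
load_model_hook([], "checkpoint-500", acc, target, lambda d: {"loaded_from": d})
assert target.state == {"loaded_from": "checkpoint-500"}
```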
Gentle ping @leisuzz
@sayakpaul I tested it yesterday: with the `DistributedType` check, the checkpoint can be saved. And the `if weights:` condition removes an issue I encountered once, which is why I added it.
Yeah, but we haven't fully addressed the comments yet. Specifically, we haven't addressed what's needed to load state dicts properly when using DeepSpeed. We need to address that.
I will check that |
What does this PR do?
These modifications help save checkpoints during training. Otherwise the run gets stuck for too long and times out.
Fixes #2606: `save_state` gets stuck when using the DeepSpeed backend while training with `train_text_to_image_lora`.
Fixes a bug where a weight was popped from an empty list.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.