
Fix issues for saving checkpointing steps #9891

Closed

leisuzz wants to merge 0 commits into huggingface:main from leisuzz:main

Conversation

@leisuzz
Contributor

@leisuzz leisuzz commented Nov 8, 2024

What does this PR do?

  1. These modifications help save checkpoints during training. Without them, training gets stuck for too long while saving state and eventually times out.
    Fixes get stuck when save_state using DeepSpeed backend under training train_text_to_image_lora #2606

  2. Bug fix for popping a weight from an empty list
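The stall described in item 1 likely comes from state saving being a collective operation under DeepSpeed: if only the main process calls it, the other ranks wait at the collective until timeout. A minimal sketch of the resulting guard, assuming accelerate-style names (`is_main_process`, `DistributedType.DEEPSPEED`); the helper `should_save_state` is illustrative, not from the PR:

```python
def should_save_state(is_main_process: bool,
                      distributed_type: str,
                      global_step: int,
                      checkpointing_steps: int) -> bool:
    """Decide whether this rank should call accelerator.save_state().

    Under DeepSpeed every rank must participate in the save; outside
    DeepSpeed only the main process should write the checkpoint.
    """
    if global_step % checkpointing_steps != 0:
        return False
    return is_main_process or distributed_type == "DEEPSPEED"
```

With this guard, a non-main DeepSpeed rank still enters the save path instead of hanging at the collective.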

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@leisuzz
Contributor Author

leisuzz commented Nov 13, 2024

@sayakpaul Please take a look at this PR, thanks for your help!

Member

@sayakpaul sayakpaul left a comment

Thanks for your PR. Can you please modify a single file first so we can discuss the changes?

@leisuzz
Contributor Author

leisuzz commented Nov 13, 2024

When I ran DreamBooth Flux without LoRA and saved a checkpoint, it got stuck for a while and then broke. So I think all the scripts need these modifications. I can restrict the changes to the Flux scripts if you want.

@sayakpaul
Member

Yeah, let's change a single file first and then we can discuss the changes.

@leisuzz
Contributor Author

leisuzz commented Nov 13, 2024

Sure

@leisuzz
Contributor Author

leisuzz commented Nov 14, 2024

@sayakpaul I have now applied the modifications only to the FLUX scripts

Comment on lines +1670 to +1671
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
Member

Well, there is a better way to handle it:

if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
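The snippet under review sits just before the `checkpoints_total_limit` rotation. As a side illustration of that rotation logic (a hypothetical pure helper named `checkpoints_to_delete`, not code from the script):

```python
import re

def checkpoints_to_delete(existing: list[str], total_limit: int) -> list[str]:
    """Given directory names like 'checkpoint-1000', return the oldest
    checkpoints to remove so that, after the upcoming save, at most
    `total_limit` checkpoints remain on disk."""
    ckpts = sorted(
        (d for d in existing if re.fullmatch(r"checkpoint-\d+", d)),
        key=lambda d: int(d.split("-")[1]),
    )
    # +1 accounts for the checkpoint that is about to be written.
    excess = len(ckpts) - total_limit + 1
    return ckpts[:excess] if excess > 0 else []
```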

Member

@sayakpaul sayakpaul left a comment

Thanks for your PR!

You can refer to the following scripts to see how we handle saving and loading from checkpoints when using DeepSpeed.

Search for DistributedType.DEEPSPEED.
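The pattern in those scripts can be read, roughly, as the following sketch. The stubs below stand in for the real accelerate objects (`accelerate.utils.DistributedType`, the registered save hook), and `save_fn` is a hypothetical callback; this is illustrative, not the actual diffusers code:

```python
class DistributedType:
    # Stand-in for accelerate.utils.DistributedType
    DEEPSPEED = "DEEPSPEED"
    MULTI_GPU = "MULTI_GPU"

def save_model_hook(models, weights, distributed_type, save_fn):
    """Pop each model and save it, but only outside DeepSpeed;
    the DeepSpeed engine checkpoints its own partitioned state."""
    saved = []
    if distributed_type != DistributedType.DEEPSPEED:
        while len(models) > 0:
            model = models.pop()
            save_fn(model)
            saved.append(model)
            if weights:  # DeepSpeed passes an empty weights list
                weights.pop()
    return saved
```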

@leisuzz
Contributor Author

leisuzz commented Nov 15, 2024

@sayakpaul I've changed it based on the reference

Member

@sayakpaul sayakpaul left a comment

Thanks! Left more comments. LMK if they're clear.


            model.load_state_dict(load_model.state_dict())
        except Exception:
    elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):
Member

We don't support fine-tuning the T5 model. So, this seems wrong. It should just be CLIPTextModelWithProjection, no?
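The dispatch the reviewer is asking for might look like the following sketch, with stub classes standing in for the transformers models (illustrative only; the real classes come from `transformers` and `diffusers`):

```python
# Stand-ins for the real model classes referenced in the review.
class FluxTransformer2DModel: ...
class CLIPTextModelWithProjection: ...
class T5EncoderModel: ...

def checkpoint_subfolder(model) -> str:
    """Map a trained model to the checkpoint subfolder it loads from.
    T5 is deliberately unsupported: the Flux scripts don't fine-tune it."""
    if isinstance(model, FluxTransformer2DModel):
        return "transformer"
    if isinstance(model, CLIPTextModelWithProjection):
        return "text_encoder"
    raise ValueError(f"Unsupported model found: {type(model)=}")
```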

Comment on lines +1211 to +1218
    try:
        load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_2")
        model(**load_model.config)
        model.load_state_dict(load_model.state_dict())
    except Exception:
        raise ValueError(f"Couldn't load the model of type: ({type(model)}).")
else:
    raise ValueError(f"Unsupported model found: {type(model)=}")
Member

Same for this.

try:
    load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder")
    model(**load_model.config)
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
Member

We also need to handle the case when we're actually doing DeepSpeed training. Similar to:
https://github.com/a-r-r-o-w/cogvideox-factory/blob/d63a826f37758eccf226710f94f6c3a4d4ee7a25/training/cogvideox_text_to_video_sft.py#L385


while len(models) > 0:
    model = models.pop()
    if not accelerator.distributed_type == DistributedType.DEEPSPEED:
Member

@sayakpaul sayakpaul Nov 15, 2024

Same. We also need to handle the case when we are not doing DeepSpeed training. Reference:
https://github.com/a-r-r-o-w/cogvideox-factory/blob/d63a826f37758eccf226710f94f6c3a4d4ee7a25/training/cogvideox_text_to_video_lora.py#L396

if weights:
    weights.pop()

def load_model_hook(models, input_dir):
Member

Seems like we're not handling the loading case appropriately here. I have repeated this multiple times now, but please refer to the changes here to get an idea of what is required.

In summary, we're not dealing with the changes required to load the state dict in the models being trained when DeepSpeed is enabled.
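A hedged sketch of that loading branch, with stand-in names (`load_fn` is a hypothetical callback for the from_pretrained + load_state_dict sequence in the real script; `DistributedType` mirrors the accelerate enum):

```python
class DistributedType:
    # Stand-in for accelerate.utils.DistributedType
    DEEPSPEED = "DEEPSPEED"
    MULTI_GPU = "MULTI_GPU"

def load_model_hook(models, input_dir, distributed_type, load_fn):
    """Outside DeepSpeed, pop each model and copy weights in from
    `input_dir`; under DeepSpeed the engine restores its own state,
    so the hook leaves `models` untouched."""
    loaded = []
    if distributed_type != DistributedType.DEEPSPEED:
        while len(models) > 0:
            model = models.pop()
            load_fn(model, input_dir)
            loaded.append(model)
    return loaded
```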

@sayakpaul
Member

Gentle ping @leisuzz

@leisuzz
Contributor Author

leisuzz commented Nov 26, 2024

@sayakpaul I tested yesterday: with the DistributedType check, the checkpoint can be saved. And the if weights: condition removes an issue I encountered once, which is why I added it.
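The failure mode behind that if weights: guard can be reproduced with plain lists (illustrative only; the function names are hypothetical):

```python
def pop_weight_unguarded(weights):
    # Raises IndexError when DeepSpeed hands the hook an empty list.
    return weights.pop()

def pop_weight_guarded(weights):
    # The `if weights:` fix from this PR: only pop when something is there.
    if weights:
        return weights.pop()
    return None
```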

@sayakpaul
Member

@sayakpaul I tested yesterday: with the DistributedType check, the checkpoint can be saved. And the if weights: condition removes an issue I encountered once, which is why I added it.

Yeah but we have not addressed the comments fully yet. Specifically, we haven't addressed what's needed to enable loading of state dicts properly when using DeepSpeed. We need to address that.

@leisuzz
Contributor Author

leisuzz commented Nov 26, 2024

I will check that



Development

Successfully merging this pull request may close these issues.

get stuck when save_state using DeepSpeed backend under training train_text_to_image_lora

2 participants