
Use torch's AdamW in examples/pytorch/BertNewsClassification/bert_classification.py #8730

Merged
harupy merged 1 commit into mlflow:master from harupy:fix-bert-classification
Jun 14, 2023

@harupy (Member) commented Jun 14, 2023

Related Issues/PRs

#xxx

What changes are proposed in this pull request?

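examples/pytorch/BertNewsClassification/bert_classification.py started failing last week with the following error. I suspect a change in transformers 4.30.x; the call stack passes through transformers' own AdamW.step (transformers/optimization.py). I confirmed locally that the example works with torch.optim.AdamW instead.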

https://pypi.org/project/transformers/4.30.0

Traceback (most recent call last):
  File "bert_classification.py", line 382, in <module>
    cli_main()
  File "bert_classification.py", line 377, in cli_main
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py", line 554, in safe_patch_function
    patch_function(call_original, *args, **kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py", line 254, in patch_with_managed_run
    result = patch_function(original, *args, **kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/mlflow/pytorch/_lightning_autolog.py", line 389, in patched_fit
    result = original(self, *args, **kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py", line 535, in call_original
    return call_original_fn_with_event_logging(_original_fn, og_args, og_kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py", line 470, in call_original_fn_with_event_logging
    original_fn_result = original_fn(*og_args, **og_kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/mlflow/utils/autologging_utils/safety.py", line 532, in _original_fn
    original_result = original(*_og_args, **_og_kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 935, in _run
    results = self._run_stage()
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 978, in _run_stage
    self.fit_loop.run()
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py", line 201, in run
    self.advance()
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py", line 354, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 133, in run
    self.advance(data_fetcher)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 218, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 185, in run
    self._optimizer_step(kwargs.get("batch_idx", 0), closure)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 261, in _optimizer_step
    call._call_lightning_module_hook(
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 142, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/core/module.py", line 1266, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/core/optimizer.py", line 158, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/strategies/strategy.py", line 224, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py", line 114, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/torch/optim/optimizer.py", line 280, in wrapper
    out = func(*args, **kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/transformers/optimization.py", line 439, in step
    loss = closure()
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py", line 101, in _wrap_closure
    closure_result = closure()
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 140, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 135, in closure
    self._backward_fn(step_output.closure_loss)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 233, in backward_fn
    call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 288, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/strategies/strategy.py", line 199, in backward
    self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py", line 67, in backward
    model.backward(tensor, *args, **kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/lightning/pytorch/core/module.py", line 1055, in backward
    loss.backward(*args, **kwargs)
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/runner/.mlflow/envs/mlflow-1e11c0a8ece3e3c7125d0d9ca5d64ab48e2f35dc/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

Failing CI job: https://github.com/mlflow/mlflow/actions/runs/5230096849/jobs/9443457368
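For anyone updating similar code, here is a minimal sketch of the swap, assuming the optimizer is built from the model's parameters and a learning rate. The build_optimizer helper and the nn.Linear stand-in are hypothetical, not the example's actual code. As far as I know, the two classes ship different defaults (transformers' AdamW: eps=1e-6, weight_decay=0.0; torch.optim.AdamW: eps=1e-8, weight_decay=1e-2), so the sketch passes them explicitly to keep the old behavior.

```python
import torch
from torch import nn
from torch.optim import AdamW  # PyTorch's implementation, replacing transformers' AdamW


def build_optimizer(model: nn.Module, lr: float = 1e-3) -> torch.optim.Optimizer:
    # Hypothetical helper: in a LightningModule this would live in configure_optimizers().
    # eps and weight_decay are set explicitly to mirror transformers' AdamW defaults
    # (eps=1e-6, weight_decay=0.0); torch.optim.AdamW defaults to eps=1e-8, weight_decay=1e-2.
    return AdamW(model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)


# Usage with a stand-in model (the real example uses a BERT-based classifier).
model = nn.Linear(768, 4)
optimizer = build_optimizer(model, lr=1e-3)
print(type(optimizer).__name__)
```

If the torch defaults are acceptable, the eps and weight_decay arguments can simply be dropped.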

How is this patch tested?

  • Existing unit/integration tests
  • New unit/integration tests
  • Manual tests (describe details, including test results, below)
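The traceback shows Lightning calling optimizer.step(closure=...), with the optimizer itself running the closure (forward and backward). A quick local check of that path with torch.optim.AdamW might look like the sketch below; the toy model, random data, and the closure_step_smoke_test name are hypothetical, and this is not the test used for this PR.

```python
import torch
from torch import nn


def closure_step_smoke_test(steps: int = 5) -> list:
    """Run a few torch.optim.AdamW steps driven by a closure, as Lightning does."""
    torch.manual_seed(0)  # deterministic toy weights and data
    model = nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
    loss_fn = nn.MSELoss()
    x, y = torch.randn(16, 4), torch.randn(16, 1)

    def closure():
        # Zero grads, forward, backward, return the loss: the contract
        # Lightning's automatic optimization relies on for step(closure=...).
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        return loss

    losses = []
    for _ in range(steps):
        # step() evaluates the closure and returns its loss (computed before the update).
        loss = optimizer.step(closure=closure)
        losses.append(loss.item())
    return losses


losses = closure_step_smoke_test()
print(losses)
```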

Does this PR change the documentation?

  • No. You can skip the rest of this section.
  • Yes. Make sure the changed pages / sections render correctly in the documentation preview.

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

(Details in 1-2 sentences. You can just refer to another PR with a description if this PR is part of a larger change.)

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
  • area/projects: MLproject format, project running backends
  • area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • area/server-infra: MLflow Tracking server backend
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

Signed-off-by: harupy
@harupy harupy requested a review from serena-ruan June 14, 2023 07:37
@github-actions bot added the rn/none label (List under Small Changes in Changelogs.) Jun 14, 2023
@mlflow-automation (Contributor)

Documentation preview for 378948d will be available here when this CircleCI job completes successfully.

@harupy harupy enabled auto-merge (squash) June 14, 2023 09:26
@harupy harupy requested a review from BenWilson2 June 14, 2023 11:50
@BenWilson2 (Member) left a comment

LGTM!

@harupy harupy merged commit 826ccc2 into mlflow:master Jun 14, 2023
BenWilson2 pushed a commit to BenWilson2/mlflow that referenced this pull request Jun 16, 2023
3 participants