pytorch-lightning 2.1.0
Lightning 2.1: Train Bigger, Better, Faster


Lightning AI is excited to announce the release of Lightning 2.1 ⚡ It's the culmination of work from 79 contributors who have worked on features, bug fixes, and documentation, for a total of more than 750 commits since v2.0.

The theme of 2.1 is "bigger, better, faster". Bigger, because training large multi-billion-parameter models has become even more efficient thanks to improvements in FSDP, efficient initialization, and sharded checkpointing. Better, because it's easier than ever to scale models without substantial code changes or third-party packages. Faster, because it leverages the latest hardware features to speed up training in low-bit precision, thanks to new precision plugins such as bitsandbytes and Transformer Engine.
And of course, as the name implies, this release fully leverages the latest features in PyTorch 2.1 🎉

  • Highlights
    • Improvements To Large-Scale Training With FSDP
    • True Half-Precision
    • Bitsandbytes Quantization
    • Transformer Engine
    • Lightning on TPU Goes Brrr
    • Granular Control Over Checkpoints in Fabric
  • Backward Incompatible Changes
    • PyTorch Lightning
    • Lightning Fabric
  • Full Changelog
    • PyTorch Lightning
    • Lightning Fabric
    • Lightning App
  • Contributors


Improvements To Large-Scale Training With FSDP

The FSDP strategy for training large billion-parameter models gets substantial improvements and new features in Lightning 2.1, both in Trainer and Fabric (in case you didn't know, Fabric is the latest addition to the Lightning family of tools to scale models without the boilerplate code).
FSDP is now more user-friendly to configure, has memory management and speed improvements, and we have a brand new end-to-end user guide with best practices (Trainer, Fabric).

Efficient Saving and Loading of Large Checkpoints

When training large billion-parameter models with FSDP, saving and resuming training, or even just loading model parameters for finetuning, can be challenging, as users are often plagued by out-of-memory errors and speed bottlenecks.

In 2.1, we made several improvements. Starting with saving checkpoints, we added support for distributed/sharded checkpoints, enabled through the state_dict_type setting in the strategy (#18364, #18358):


# PyTorch Lightning (Trainer)
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy

# Default used by the strategy
strategy = FSDPStrategy(state_dict_type="full")

# Enable saving distributed checkpoints
strategy = FSDPStrategy(state_dict_type="sharded")

trainer = L.Trainer(strategy=strategy)  # add your other Trainer arguments here


# Lightning Fabric
import lightning as L
from lightning.fabric.strategies import FSDPStrategy

# Saving distributed checkpoints is the default
strategy = FSDPStrategy(state_dict_type="sharded")

# Save consolidated (single file) checkpoints
strategy = FSDPStrategy(state_dict_type="full")

fabric = L.Fabric(strategy=strategy)  # add your other Fabric arguments here

Distributed checkpoints are the fastest and most memory efficient way to save the state of very large models.
The distributed checkpoint format also makes it efficient to load these checkpoints back for resuming training in parallel, and it reduces the impact on CPU memory usage significantly. Furthermore, we've also introduced lazy-loading for non-distributed checkpoints (#18150, #18379), which greatly reduces the impact on CPU memory usage when loading a consolidated (single-file) checkpoint (e.g. for finetuning). Learn more about these features in our FSDP guides (Trainer, Fabric).

Fast and Memory-Optimized Initialization

A major challenge that users face when working with large models such as LLMs is dealing with the extreme memory requirements. Even something as simple as instantiating a model becomes non-trivial if the model is so large that it won't fit on a single GPU or even a single machine. In Lightning 2.1, we are introducing empty-weights initialization through the Fabric.init_module() (#17462, #17627) and Trainer.init_module()/LightningModule.configure_model() (#18004, #18385) methods:


# PyTorch Lightning (Trainer)
import lightning as L

class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Delay initialization of model to `configure_model()`

    def configure_model(self):
        # Model initialized in correct precision and weights on meta-device
        self.model = ...


trainer = L.Trainer(strategy="fsdp")  # add your other Trainer arguments here


# Lightning Fabric
import torch
import lightning as L

fabric = L.Fabric(strategy="fsdp")  # add your other Fabric arguments here
fabric.launch()

# Model initialized in correct precision and weights on meta-device
with fabric.init_module(empty_init=True):
    model = ...

# You can also initialize buffers and tensors directly on device and dtype
with fabric.init_tensor():
    x = torch.randn(4, 128)

# Materialization and sharding of model happens inside here
model = fabric.setup(model)

Read more about this new feature and its other benefits in our docs (Trainer, Fabric).
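The memory argument behind empty-weights (meta-device) initialization can be sketched without any framework. In this hypothetical example, `MetaTensor` only records a shape, the way a tensor on PyTorch's meta device carries shape information but no storage, and `materialize_local_shard` allocates just one rank's slice afterwards. The names and the even, contiguous split are assumptions for illustration; real FSDP flattens and pads parameters differently.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class MetaTensor:
    """Shape-only placeholder: no element storage is allocated."""
    shape: tuple

    def numel(self):
        return math.prod(self.shape)


def build_on_meta(hidden, num_layers):
    """'Instantiate' a model by recording parameter shapes only."""
    return {f"layers.{i}.weight": MetaTensor((hidden, hidden)) for i in range(num_layers)}


def materialize_local_shard(meta_state, rank, world_size, fill=0.0):
    """Allocate storage only for this rank's contiguous slice of each parameter."""
    local = {}
    for name, tensor in meta_state.items():
        n = tensor.numel()
        chunk = -(-n // world_size)  # ceiling division
        start, end = min(n, rank * chunk), min(n, (rank + 1) * chunk)
        local[name] = [fill] * (end - start)
    return local


if __name__ == "__main__":
    meta = build_on_meta(hidden=64, num_layers=4)
    total = sum(t.numel() for t in meta.values())
    local = materialize_local_shard(meta, rank=0, world_size=8)
    print(total, sum(len(v) for v in local.values()))  # 16384 2048
```

Only the final, sharded allocation costs memory, and each of the eight simulated ranks holds one eighth of the elements.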

User-Friendly Configuration

We made it super easy to configure the sharding- and activation-checkpointing policy when you want to auto-wrap particular layers of your model for advanced control (#18045, #18084).

  import lightning as L
  from lightning.pytorch.strategies import FSDPStrategy
- from torch.distributed.fsdp.wrap import ModuleWrapPolicy

- strategy = FSDPStrategy(auto_wrap_policy=ModuleWrapPolicy({MyTransformerBlock}))
+ strategy = FSDPStrategy(auto_wrap_policy={MyTransformerBlock})
  trainer = L.Trainer(strategy=strategy, ...)

Furthermore, the sharding strategy can now be conveniently set with a string value (#18087):

  import lightning as L
  from lightning.pytorch.strategies import FSDPStrategy
- from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy

- strategy = FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
+ strategy = FSDPStrategy(sharding_strategy="SHARD_GRAD_OP")
  trainer = L.Trainer(strategy=strategy, ...)

You no longer need to remember the long PyTorch imports! Fabric also supports all these improvements shown above.
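Conceptually, the string and set shortcuts only need to be normalized into the objects PyTorch expects. The sketch below shows that idea with a stand-in enum: the `ShardingStrategy` class here merely mirrors a few member names of `torch.distributed.fsdp.ShardingStrategy`, and the helper functions and layer classes are hypothetical, not Lightning's internals.

```python
from enum import Enum, auto


class ShardingStrategy(Enum):
    """Stand-in mirroring some member names of torch.distributed.fsdp.ShardingStrategy."""
    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()
    HYBRID_SHARD = auto()


def normalize_sharding_strategy(value):
    """Accept an enum member or its (case-insensitive) name."""
    if isinstance(value, str):
        return ShardingStrategy[value.upper()]
    return value


def make_wrap_policy(layer_classes):
    """Turn a set of layer classes into a predicate: wrap a module iff it is one of them."""
    classes = tuple(layer_classes)

    def should_wrap(module):
        return isinstance(module, classes)

    return should_wrap


class MyTransformerBlock:  # hypothetical layer type
    pass


class Embedding:  # hypothetical layer type
    pass


if __name__ == "__main__":
    print(normalize_sharding_strategy("SHARD_GRAD_OP"))
    policy = make_wrap_policy({MyTransformerBlock})
    print(policy(MyTransformerBlock()), policy(Embedding()))
```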

True Half-Precision

Lightning now supports true half-precision for training and inference with all built-in strategies (#18193, #18217, #18213, #18219). With this setting, storing the model weights takes only half the memory normally needed with float32. You also get the same speed benefits as with mixed-precision training (precision="16-mixed"):

import lightning as L

# default
trainer = L.Trainer(precision="32-true")

# train with model weights in `torch.float16`
trainer = L.Trainer(precision="16-true")

# train with model weights in `torch.bfloat16`
# (if hardware supports it)
trainer = L.Trainer(precision="bf16-true")

The same settings are also available in Fabric! We recommend trying bfloat16 training (precision="bf16-true"), as it is often more numerically stable than regular 16-bit precision (precision="16-true").
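A quick back-of-the-envelope calculation shows what "half the memory for the weights" means in practice. The byte counts below are the standard element sizes of the respective dtypes; the helper is illustrative and counts only the weights, not gradients, optimizer state, or activations. Note that "16-mixed" keeps the weights in float32, so it does not shrink their footprint.

```python
# Bytes per parameter for the stored weights under each precision setting
BYTES_PER_PARAM = {
    "64-true": 8,    # torch.float64
    "32-true": 4,    # torch.float32 (default)
    "16-mixed": 4,   # weights stay in float32; only some ops run in float16
    "16-true": 2,    # torch.float16
    "bf16-true": 2,  # torch.bfloat16
}


def weight_memory_gb(num_params, precision):
    """Memory needed just to store the model weights, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[precision] / 1e9


if __name__ == "__main__":
    for precision in ("32-true", "16-mixed", "bf16-true"):
        print(precision, weight_memory_gb(7e9, precision), "GB")
```

For a 7-billion-parameter model this gives 28 GB in float32 versus 14 GB with "bf16-true".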

Bitsandbytes Quantization

With the new Bitsandbytes precision plugin (#18655), you can now quantize your model for significant memory savings during training, finetuning, or inference with a selection of state-of-the-art quantization algorithms (int8, fp4, nf4, and more). For the first time, Trainer and Fabric make bitsandbytes easy to use for general models.


# PyTorch Lightning (Trainer)
import lightning as L
from lightning.pytorch.plugins import BitsandbytesPrecisionPlugin

# this will pick out the compute dtype automatically, by default `bfloat16`
precision = BitsandbytesPrecisionPlugin("nf4-dq")
trainer = L.Trainer(plugins=precision)


# Lightning Fabric
import lightning as L
from lightning.fabric.plugins import BitsandbytesPrecision

# this will pick out the compute dtype automatically, by default `bfloat16`
precision = BitsandbytesPrecision("nf4-dq")
fabric = L.Fabric(plugins=precision)

Learn more!
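To build intuition for how weight quantization trades precision for memory, here is a simplified blockwise absmax int8 scheme in plain Python. It is an illustration only: bitsandbytes implements its int8, fp4, and nf4 formats in optimized GPU kernels with different details, and the function names and block size here are assumptions. Scaling per block limits the damage of a single outlier to its own block; the "-dq" (double quantization) modes such as "nf4-dq" go one step further and also quantize the per-block scale constants.

```python
def quantize_blockwise_int8(values, block_size=64):
    """Quantize floats to int8 codes in [-127, 127], with one absmax scale per block."""
    codes, scales = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        absmax = max(abs(v) for v in block) or 1.0  # avoid division by zero for all-zero blocks
        scales.append(absmax)
        codes.extend(round(v / absmax * 127) for v in block)
    return codes, scales


def dequantize_blockwise_int8(codes, scales, block_size=64):
    """Map int8 codes back to approximate floats using each block's scale."""
    return [c * scales[i // block_size] / 127 for i, c in enumerate(codes)]


if __name__ == "__main__":
    weights = [0.02, -0.4, 0.13, 0.9, 25.0, -3.0, 1.5, 0.0]
    codes, scales = quantize_blockwise_int8(weights, block_size=4)
    restored = dequantize_blockwise_int8(codes, scales, block_size=4)
    print(codes)
    print([round(r, 3) for r in restored])
```

Each weight now needs 1 byte plus a small share of its block's scale, instead of 4 bytes in float32. Because the outlier 25.0 sits in the second block, the small weights in the first block keep a fine resolution.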

Transformer Engine

The Transformer Engine by NVIDIA is a library for accelerating transformer layers on the new Hopper (H100) generation of GPUs. With the integration in Lightning Trainer and Fabric (#17597, #18459), you get easy access to 8-bit mixed precision for significant speedups:


# PyTorch Lightning (Trainer)
import lightning as L

# Select 8-bit mixed precision via TransformerEngine, with model weights in float16
trainer = L.Trainer(precision="transformer-engine-float16")


# Lightning Fabric
import lightning as L

# Select 8-bit mixed precision via TransformerEngine, with model weights in float16
fabric = L.Fabric(precision="transformer-engine-float16")

More configuration options are available through the respective plugins in Trainer and Fabric.

Lightning on TPU Goes Brrr

Lightning 2.1 runs on the latest generation of TPU hardware on Google Cloud! TPU-v4 and TPU-v5 (#17227) are now fully supported both in Fabric and Trainer and run using the new PjRT runtime by default (#17352). PjRT is the runtime used by JAX and has shown an average improvement of 35% on benchmarks.


# PyTorch Lightning (Trainer)
import lightning as L

trainer = L.Trainer(accelerator="tpu", devices=8)
model = MyModel()
trainer.fit(model)  # uses PjRT if available


# Lightning Fabric
import lightning as L

def train(fabric):
    ...  # your training code

fabric = L.Fabric(accelerator="tpu")
fabric.launch(train)  # uses PjRT if available

And what's even more exciting, you can now scale massive multi-billion parameter models on TPUs using FSDP (#17421).

import lightning as L
from lightning.fabric.strategies import XLAFSDPStrategy

strategy = XLAFSDPStrategy(
    # Most arguments from the PyTorch native FSDP strategy are also available here!
)
fabric = L.Fabric(devices=8, strategy=strategy)

You can find a full end-to-end finetuning example script in our Lit-GPT repository. The new XLA-FSDP strategy is experimental and currently only available in Fabric. Support in the Trainer will follow in the future.

Granular Control Over Checkpoints in Fabric

Several improvements for checkpoint saving and loading have landed in Fabric, enabling more fine-grained control over what is saved/loaded while reducing boilerplate code:

  1. There is a new Fabric.load_raw() method with which you can load model- or optimizer state-dicts saved externally by a non-Fabric application (e.g., raw PyTorch) (#18049)

    import lightning as L
    fabric = L.Fabric()
    model = MyModel()
    # A model weights file saved by your friend who doesn't use Fabric
    fabric.load_raw("path/to/", model)
    # Equivalent to this:
    # model.load_state_dict(torch.load("path/to/"))
  2. A new parameter Fabric.load(..., strict=True|False) to disable strict loading (#17645)

    import lightning as L
    fabric = L.Fabric()
    model = MyModel()
    state = {"model": model}
    # strict loading is the default
    fabric.load("path/to/checkpoint.ckpt", state, strict=True)
    # disable strict loading
    fabric.load("path/to/checkpoint.ckpt", state, strict=False)
  3. A new parameter, Fabric.save(..., filter=...), that enables you to exclude certain parameters of your model without writing boilerplate code for it (#17845)

    import lightning as L
    fabric = L.Fabric()
    model, optimizer = ...
    state = {"model": model, "optimizer": optimizer, "foo": 123}
    # save only the weights that match a pattern
    filter = {"model": lambda k, v: "weight" in k}
    fabric.save("path/to/checkpoint.ckpt", state, filter=filter)
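The semantics of `strict` and `filter` can be illustrated with plain dictionaries standing in for state dicts. This is a hypothetical sketch of the behavior described in the list above, not Fabric's implementation: `load_state` reports missing and unexpected keys and raises only when `strict=True`, and `apply_filter` keeps the entries for which the per-object predicate (the same `lambda k, v: ...` shape as in the example) returns True.

```python
def load_state(target, checkpoint, strict=True):
    """Copy matching entries from checkpoint into target; optionally fail on key mismatches."""
    missing = sorted(set(target) - set(checkpoint))
    unexpected = sorted(set(checkpoint) - set(target))
    if strict and (missing or unexpected):
        raise KeyError(f"missing={missing}, unexpected={unexpected}")
    for key in set(target) & set(checkpoint):
        target[key] = checkpoint[key]
    return missing, unexpected


def apply_filter(state, filters):
    """Keep only the entries of each state dict whose (key, value) pair passes its predicate."""
    result = {}
    for name, obj_state in state.items():
        predicate = filters.get(name)
        if predicate is None or not isinstance(obj_state, dict):
            result[name] = obj_state  # no filter for this object: keep it unchanged
        else:
            result[name] = {k: v for k, v in obj_state.items() if predicate(k, v)}
    return result


if __name__ == "__main__":
    state = {"model": {"fc.weight": [1.0], "fc.bias": [0.0]}, "foo": 123}
    print(apply_filter(state, {"model": lambda k, v: "weight" in k}))
    model = {"fc.weight": None, "fc.bias": None}
    print(load_state(model, {"fc.weight": [1.0]}, strict=False))
```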

You can read more about the new options in our checkpoint guide.

Backward Incompatible Changes

The release of PyTorch Lightning 2.0 was a big step into a new chapter: it brought a more polished API and removed a lot of legacy code as well as outdated and experimental features, at the cost of a long list of breaking changes that made upgrading from 1.9 to 2.0 more work than usual. Moving forward, we promised to maintain full backward compatibility of our public core APIs to guarantee a smooth upgrade experience for everyone, and with 2.1 we are happy to deliver on this promise. We made a few exceptions where a change significantly improves the user experience or performance, or fixes the correctness of a feature. These changes will likely not impact most users.

PyTorch Lightning

TPU/XLA Changes

When selecting device indices via devices=[i], the Trainer now selects the i-th TPU core (0-based, previously it was 1-based) (#17227)


# Lightning 2.0: selects the first TPU core (1-based index)
trainer = Trainer(accelerator="tpu", devices=[1])


# Lightning 2.1: selects the second TPU core (0-based index)
trainer = Trainer(accelerator="tpu", devices=[1])

Multi-GPU in Jupyter Notebooks

Due to lack of reliability, Trainer now only runs on one GPU instead of all GPUs in a Jupyter notebook if devices="auto" (default) (#18291)


import lightning as L

# In Jupyter notebooks, this would select all available GPUs (DDP)
trainer = L.Trainer(accelerator="cuda", devices="auto")


# In Jupyter notebooks, this now selects only one GPU (the first)
trainer = L.Trainer(accelerator="cuda", devices="auto")

# You can still explicitly select multiple
trainer = L.Trainer(accelerator="cuda", devices=8)
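The new rule can be summarized as a small decision function. This is a hypothetical sketch, not Lightning's code: the function name is made up, and whether the session is a notebook is passed in as a flag rather than detected, since detection logic is an implementation detail not covered here.

```python
def resolve_num_gpus(devices, num_available, interactive):
    """How many GPUs a run would use under the 2.1 rule for devices="auto"."""
    if devices == "auto":
        # In notebooks, "auto" now means a single GPU; in scripts, all available GPUs
        return 1 if interactive else num_available
    if isinstance(devices, int):
        return devices  # an explicit count is always respected
    return len(devices)  # an explicit list of device indices


if __name__ == "__main__":
    print(resolve_num_gpus("auto", 8, interactive=True))   # 1
    print(resolve_num_gpus("auto", 8, interactive=False))  # 8
    print(resolve_num_gpus(8, 8, interactive=True))        # 8
```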

Device Access in Setup Hook

  • During LightningModule.setup(), self.device now returns the device the module will be placed on instead of cpu (#18021)


# Lightning 2.0
def setup(self, stage):
    # CPU regardless of the accelerator used
    print(self.device)


# Lightning 2.1
def setup(self, stage):
    # CPU/CUDA/MPS/XLA depending on accelerator
    print(self.device)

Miscellaneous Changes

  • Tensors logged with self.log are now kept on their original device to reduce unnecessary host-to-device synchronizations (#17334)
  • The FSDPStrategy now loads checkpoints after the configure_model/configure_sharded_model hook (#18358)
  • The FSDPStrategy.load_optimizer_state_dict and FSDPStrategy.load_model_state_dict methods are now no-ops (#18358)
  • Removed experimental support for torchdistx due to a lack of project maintenance (#17995)
  • Dropped support for PyTorch 1.11 (#18691)

Lightning Fabric

We thank the community for the amazing feedback we got for Fabric so far - keep it coming. The list of breaking changes is short and won't affect the vast majority of users.

Sharding Context Manager in

We removed automatic sharding support with or using fabric.launch(fn). This only impacts FSDP and DeepSpeed strategy users who use this way of launching. Please note that this is a legacy construct from the LightningLite days and is not recommended today. Instead, instantiate your large FSDP or DeepSpeed model under the newly added fabric.init_module context manager (#17832).


import lightning as L

def train(fabric):
    # FSDP's `enable_wrap` context or ``
    # were applied automatically here (Lightning 2.0)
    model = LargeModel()

fabric = L.Fabric()
fabric.launch(train)


def train(fabric):
    # Use `init_module` explicitly to apply these context managers
    with fabric.init_module():
        model = LargeModel()

Multi-GPU in Jupyter Notebooks

Due to lack of reliability, Fabric now only runs on one GPU instead of all GPUs in a Jupyter notebook if devices="auto" (default) (#18291)


import lightning as L

# In Jupyter notebooks, this would select all available GPUs (DDP)
fabric = L.Fabric(accelerator="cuda", devices="auto")


# In Jupyter notebooks, this now selects only one GPU (the first)
fabric = L.Fabric(accelerator="cuda", devices="auto")

# You can still explicitly select multiple
fabric = L.Fabric(accelerator="cuda", devices=8)

Full Changelog

PyTorch Lightning

  • Added metrics_format attribute to RichProgressBarTheme class (#18373)
  • Added CHECKPOINT_EQUALS_CHAR attribute to ModelCheckpoint class (#17999)
  • Added **summarize_kwargs to ModelSummary and RichModelSummary callbacks (#16788)
  • Added support for the max_size_cycle|max_size|min_size iteration modes during evaluation (#17163)
  • Added support for the TPU-v4 architecture (#17227)
  • Added support for XLA's new PJRT runtime (#17352)
  • Check for invalid TPU device inputs (#17227)
  • Added XLAStrategy(sync_module_states=bool) to control whether to broadcast the parameters to all devices (#17522)
  • Added support for multiple optimizer parameter groups when using the FSDP strategy (#17309)
  • Enabled saving the full model state dict when using the FSDPStrategy (#16558)
  • Update LightningDataModule.from_datasets to support arbitrary iterables (#17402)
  • Run the DDP wrapper in a CUDA stream (#17334)
  • Added SaveConfigCallback.save_config to ease use cases such as saving the config to a logger (#17475)
  • Enabled optional file versioning of model checkpoints (#17320)
  • Added the process group timeout argument FSDPStrategy(timeout=...) for the FSDP strategy (#17274)
  • Added FSDPStrategy(activation_checkpointing_policy=...) to customize the layer policy for automatic activation checkpointing (requires torch>=2.1) (#18045)
  • Added CLI option --map-to-cpu to the checkpoint upgrade script to enable converting GPU checkpoints on a CPU-only machine (#17527)
  • Added non-layer param count to the model summary (#17005)
  • Updated LearningRateMonitor to log monitored values to trainer.callback_metrics (#17626)
  • Added log_weight_decay argument to LearningRateMonitor callback (#18439)
  • Added Trainer.print() to print on local rank zero only (#17980)
  • Added Trainer.init_module() context manager to instantiate large models efficiently directly on device, dtype (#18004)
    • Creates the model parameters in the desired dtype (torch.float32, torch.float64) depending on the 'true' precision choice in Trainer(precision='32-true'|'64-true')
  • Added the LightningModule.configure_model() hook to instantiate large models efficiently directly on device, dtype, and with sharding support (#18004)
    • Handles initialization for FSDP models before wrapping and the ZeRO stage 3 initialization for DeepSpeed before sharding
  • Added support for meta-device initialization with Trainer.init_module(empty_init=True) in FSDP (#18385)
  • Added lightning.pytorch.plugins.PrecisionPlugin.module_init_context() and lightning.pytorch.strategies.Strategy.tensor_init_context() context managers to control model and tensor instantiation (#18004)
  • Automatically call xla_model.mark_step() before saving checkpoints with XLA (#17882)
  • Added a callback for spike-detection (#18014)
  • Added the ability to set the torch.distributed.fsdp.ShardingStrategy via string in FSDPStrategy (#18087)
  • Improved error messages when attempting to load a DeepSpeed checkpoint at an invalid path (#17795)
  • Allowed accessing rank information in the main process before processes are launched when using the XLAStrategy (#18194)
  • Added support for true half-precision training via Trainer(precision="16-true"|"bf16-true") (#18193, #18217, #18213, #18219)
  • Added automatic process cleanup to avoid zombie child processes and stalls when exceptions are raised (#18218)
  • Added validation of user input for devices and num_nodes when running with SLURM or TorchElastic (#18292)
  • Added support for saving checkpoints with either full state-dict or sharded state dict via FSDPStrategy(state_dict_type="full"|"sharded") (#18364)
  • Added support for loading sharded/distributed checkpoints in FSDP (#18358)
  • Made the text delimiter in the rich progress bar configurable (#18372)
  • Improved the error messaging and instructions when handling custom batch samplers in distributed settings (#18402)
  • Added support for mixed 8-bit precision as Trainer(precision="transformer-engine") using Nvidia's Transformer Engine (#18459)
  • Added support for linear layer quantization with Trainer(plugins=BitsandbytesPrecision()) using bitsandbytes (#18655)
  • Added support for passing the process group to the FSDPStrategy (#18583)
  • Enabled the default process group configuration for FSDP's hybrid sharding (#18583)
  • Added lightning.pytorch.utilities.suggested_max_num_workers to assist with setting a good value in distributed settings (#18591)
  • Improved the num_workers warning to give a more accurate upper limit on the num_workers suggestion (#18591)
  • Added lightning.pytorch.utilities.is_shared_filesystem utility function to automatically check whether the filesystem is shared between machines (#18586)
  • Added support for returning an object of type Mapping from LightningModule.training_step() (#18657)
  • Added the hook LightningModule.on_validation_model_zero_grad() to allow overriding the behavior of zeroing the gradients before entering the validation loop (#18710)
  • Changed default metric formatting from round(..., 3) to ".3f" format string in MetricsTextColumn class (#18483)
  • Removed the limitation to call self.trainer.model.parameters() in LightningModule.configure_optimizers() (#17309)
  • Trainer(accelerator="tpu", devices=[i]) now selects the i-th TPU core (0-based, previously it was 1-based) (#17227)
  • Allow using iterable-style datasets with TPUs (#17331)
  • Increased the minimum XLA requirement to 1.13 (#17368)
  • Tensors logged with self.log are now kept on their original device to reduce unnecessary host-to-device synchronizations (#17334)
  • Made the run initialization in WandbLogger lazy to avoid creating artifacts when the CLI is used (#17573)
  • Simplified redirection of *_step methods in strategies by removing the _LightningModuleWrapperBase wrapper module (#17531)
  • Support kwargs input for LayerSummary (#17709)
  • Dropped support for wandb versions older than 0.12.0 in WandbLogger (#17876)
  • During LightningModule.setup(), self.device now returns the device the module will be placed on instead of cpu (#18021)
  • Increased the minimum supported wandb version for WandbLogger from 0.12.0 to 0.12.10 (#18171)
  • The input tensors now get cast to the right precision type before transfer to the device (#18264)
  • Improved the formatting of emitted warnings (#18288)
  • Broadcast and reduction of tensors with XLA-based strategies now preserve the input's device (#18275)
  • The FSDPStrategy now loads checkpoints after the configure_model/configure_sharded_model hook (#18358)
  • The FSDPStrategy.load_optimizer_state_dict and FSDPStrategy.load_model_state_dict methods are now no-ops (#18358)
  • The Trainer.num_val_batches, Trainer.num_test_batches and Trainer.num_sanity_val_batches now return a list of sizes per dataloader instead of a single integer (#18441)
  • The *_step(dataloader_iter) flavor no longer takes the batch_idx in the signature (#18390)
  • Calling next(dataloader_iter) now returns a triplet (batch, batch_idx, dataloader_idx) (#18390)
  • Calling next(combined_loader) now returns a triplet (batch, batch_idx, dataloader_idx) (#18390)
  • Due to lack of reliability, Trainer now only runs on one GPU instead of all GPUs in a Jupyter notebook if devices="auto" (default) (#18291)
  • Made the batch_idx argument optional in validation_step, test_step and predict_step to maintain consistency with training_step (#18512)
  • The TQDMProgressBar now consistently shows it/s for the speed even when the iteration time becomes larger than one second (#18593)
  • The LightningDataModule.load_from_checkpoint and LightningModule.load_from_checkpoint methods now raise an error if they are called on an instance instead of the class (#18432)
  • Enabled launching via torchrun in a SLURM environment; the TorchElasticEnvironment now gets chosen over the SLURMEnvironment if both are detected (#18618)
  • If not set by the user, Lightning will set OMP_NUM_THREADS to num_cpus / num_processes when launching subprocesses (e.g. when DDP is used) to avoid system overload for CPU-intensive tasks (#18677)
  • The ModelCheckpoint no longer deletes files under the save-top-k mechanism when resuming from a folder that is not the same as the current checkpoint folder (#18750)
  • The ModelCheckpoint no longer deletes the file that was passed to (#18750)
  • Calling twice now raises an error with strategies that spawn subprocesses through multiprocessing (ddp_spawn, xla) (#18776)
  • The ModelCheckpoint now saves a symbolic link if save_last=True and save_top_k != 0 (#18748)
  • Deprecated the SingleTPUStrategy (strategy="single_tpu") in favor of SingleDeviceXLAStrategy (strategy="single_xla") (#17383)
  • Deprecated the TPUAccelerator in favor of XLAAccelerator (#17383)
  • Deprecated the TPUPrecisionPlugin in favor of XLAPrecisionPlugin (#17383)
  • Deprecated the TPUBf16PrecisionPlugin in favor of XLABf16PrecisionPlugin (#17383)
  • Deprecated the Strategy.post_training_step method (#17531)
  • Deprecated the LightningModule.configure_sharded_model hook in favor of LightningModule.configure_model (#18004)
  • Deprecated the LightningDoublePrecisionModule wrapper in favor of calling Trainer.precision_plugin.convert_input() (#18209)
  • Removed the XLAStrategy.is_distributed property. It is always True (#17381)
  • Removed the SingleTPUStrategy.is_distributed property. It is always False (#17381)
  • Removed experimental support for torchdistx due to a lack of project maintenance (#17995)
  • Removed support for PyTorch 1.11 (#18691)
  • Fixed an issue with reusing the same model across multiple trainer stages when using the DeepSpeedStrategy (#17531)
  • Fixed the saving and loading of FSDP optimizer states (#17819)
  • Fixed FSDP re-applying activation checkpointing when the user had manually applied it already (#18006)
  • Fixed issue where unexpected exceptions would leave the default torch dtype modified when using true precision settings (#18500)
  • Fixed issue where not including the batch_idx argument in the training_step would disable gradient accumulation (#18619)
  • Fixed the replacement of callbacks returned in LightningModule.configure_callbacks when the callback was a subclass of an existing Trainer callback (#18508)
  • Fixed Trainer.log_dir not returning the correct directory for the CSVLogger (#18548)
  • Fixed redundant input-type casting in FSDP precision (#18630)
  • Fixed numerical issues when reducing values in low precision with self.log (#18686)
  • Fixed an issue that would cause the gradients to be erased if validation happened in the middle of a gradient accumulation phase (#18710)
  • Fixed redundant file writes in CSVLogger (#18567)
  • Fixed an issue that could lead to checkpoint files being deleted accidentally when resuming training (#18750)
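The OMP_NUM_THREADS rule from the changelog above comes down to simple arithmetic. The sketch below assumes integer division with a floor of one thread per process; the helper names are hypothetical, and the environment is passed in as a plain dict so the example does not modify the real process environment.

```python
import os


def suggested_num_threads(num_cpus, num_processes):
    """Split the available CPU cores evenly across processes, never going below 1."""
    return max(1, num_cpus // num_processes)


def set_omp_threads_if_unset(env, num_processes, num_cpus=None):
    """Respect a user-provided OMP_NUM_THREADS; otherwise fill in the suggested value."""
    if "OMP_NUM_THREADS" not in env:
        cpus = num_cpus if num_cpus is not None else (os.cpu_count() or 1)
        env["OMP_NUM_THREADS"] = str(suggested_num_threads(cpus, num_processes))
    return env["OMP_NUM_THREADS"]


if __name__ == "__main__":
    print(set_omp_threads_if_unset({}, num_processes=8, num_cpus=64))  # 8
    print(set_omp_threads_if_unset({"OMP_NUM_THREADS": "2"}, num_processes=8))  # 2
```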

Lightning Fabric

  • Added support for the TPU-v4 architecture (#17227)
  • Added support for XLA's new PJRT runtime (#17352)
  • Added support for Fully Sharded Data Parallel (FSDP) training with XLA (#18126, #18424, #18430)
  • Check for invalid TPU device inputs (#17227)
  • Added XLAStrategy(sync_module_states=bool) to control whether to broadcast the parameters to all devices (#17522)
  • Added support for joint setup of model and optimizer with FSDP (#17305)
  • Added support for handling multiple parameter groups in optimizers set up with FSDP (#17305)
  • Added support for saving and loading sharded model and optimizer state with FSDPStrategy (#17323)
  • Added a warning when calling methods on _FabricModule that bypass the strategy-specific wrappers (#17424)
  • Added Fabric.init_tensor() context manager to instantiate tensors efficiently directly on device and dtype (#17488)
  • Added Fabric.init_module() context manager to instantiate large models efficiently directly on device, dtype, and with sharding support (#17462)
    • Creates the model parameters in the desired dtype (torch.float32, torch.float64, torch.float16, or torch.bfloat16) depending on the 'true' precision choice in Fabric(precision='32-true'|'64-true'|'16-true'|'bf16-true')
    • Handles initialization for FSDP models before wrapping and the ZeRO stage 3 initialization for DeepSpeed before sharding
  • Added support for empty weight initialization with Fabric.init_module(empty_init=True) for checkpoint loading (#17627)
  • Added support for meta-device initialization with Fabric.init_module(empty_init=True) in FSDP (#18122)
  • Added lightning.fabric.plugins.Precision.module_init_context() and lightning.fabric.strategies.Strategy.module_init_context() context managers to control model and tensor instantiation (#17462)
  • Added the lightning.fabric.strategies.Strategy.tensor_init_context() context manager to instantiate tensors efficiently directly on device and dtype (#17607)
  • Run the DDP wrapper in a CUDA stream (#17334)
  • Added support for true half-precision as Fabric(precision="16-true"|"bf16-true") (#17287)
  • Added support for mixed 8-bit precision as Fabric(precision="transformer-engine") using Nvidia's Transformer Engine (#17597)
  • Added support for linear layer quantization with Fabric(plugins=BitsandbytesPrecision()) using bitsandbytes (#18655)
  • Added error messaging for missed .launch() when it is required (#17570)
  • Added support for saving checkpoints with either full state-dict or sharded state dict via FSDPStrategy(state_dict_type="full"|"sharded") (#17526)
  • Added support for loading a full-state checkpoint file into a sharded model (#17623)
  • Added support for calling hooks on a LightningModule via (#17874)
  • Added the parameter Fabric.load(..., strict=True|False) to enable non-strict loading of partial checkpoint state (#17645)
  • Added the parameter Fabric.save(..., filter=...) to enable saving a partial checkpoint state (#17845)
  • Added support for loading optimizer states from a full-state checkpoint file (#17747)
  • Automatically call xla_model.mark_step() before saving checkpoints with XLA (#17882)
  • Automatically call xla_model.mark_step() after optimizer.step() with XLA (#17883)
  • Added support for all half-precision modes in FSDP precision plugin (#17807)
  • Added FSDPStrategy(activation_checkpointing_policy=...) to customize the layer policy for automatic activation checkpointing (requires torch>=2.1) (#18045)
  • Added a callback for spike-detection (#18014)
  • Added the ability to set the torch.distributed.fsdp.ShardingStrategy via string in FSDPStrategy (#18087)
  • Improved error messages when attempting to load a DeepSpeed checkpoint at an invalid path (#17795)
  • Added Fabric.load_raw() for loading raw PyTorch state dict checkpoints for model or optimizer objects (#18049)
  • Allowed accessing rank information in the main process before processes are launched when using the XLAStrategy (#18194)
  • Added automatic process cleanup to avoid zombie child processes and stalls when exceptions are raised (#18218)
  • Added validation of user input for devices and num_nodes when running with SLURM or TorchElastic (#18292)
  • Improved the error messaging and instructions when handling custom batch samplers in distributed settings (#18402)
  • Added support for saving and loading stateful objects other than modules and optimizers (#18513)
  • Enabled the default process group configuration for FSDP's hybrid sharding (#18583)
  • Added lightning.fabric.utilities.suggested_max_num_workers to assist with setting a good value in distributed settings (#18591)
  • Added lightning.fabric.utilities.is_shared_filesystem utility function to automatically check whether the filesystem is shared between machines (#18586)
  • Removed support for PyTorch 1.11 (#18691)
  • Added support for passing the argument .load_state_dict(..., assign=True|False) on Fabric-wrapped modules in PyTorch 2.1 or newer (#18690)
  • Allow using iterable-style datasets with TPUs (#17331)
  • Increased the minimum XLA requirement to 1.13 (#17368)
  • Fabric argument validation now only raises an error if conflicting settings are set through the CLI (#17679)
  • DataLoader re-instantiation is now only performed when a distributed sampler is required (#18191)
  • Improved the formatting of emitted warnings (#18288)
  • Broadcast and reduction of tensors with XLA-based strategies now preserve the input's device (#18275)
  • Due to lack of reliability, Fabric now only runs on one GPU instead of all GPUs in a Jupyter notebook if devices="auto" (default) (#18291)
  • Enabled launching via torchrun in a SLURM environment; the TorchElasticEnvironment now gets chosen over the SLURMEnvironment if both are detected (#18618)
  • If not set by the user, Lightning will set OMP_NUM_THREADS to num_cpus / num_processes when launching subprocesses (e.g. when DDP is used) to avoid system overload for CPU-intensive tasks (#18677)
  • Deprecated the DDPStrategy.is_distributed property. This strategy is distributed by definition (#17381)
  • Deprecated the SingleTPUStrategy (strategy="single_tpu") in favor of SingleDeviceXLAStrategy (strategy="single_xla") (#17383)
  • Deprecated the TPUAccelerator in favor of XLAAccelerator (#17383)
  • Deprecated the TPUPrecision in favor of XLAPrecision (#17383)
  • Deprecated the TPUBf16Precision in favor of XLABf16Precision (#17383)
  • Removed automatic sharding support with or using fabric.launch(fn). This only impacts FSDP and DeepSpeed strategy users. Please instantiate your module under the newly added fabric.init_module context manager (#17832)
  • Removed the unsupported checkpoint_io argument from the FSDPStrategy (#18192)
  • Fixed issue where running on TPUs would select the wrong device index (#17227)
  • Removed the need to call .launch() when using the DP-strategy (strategy="dp") (#17931)
  • Fixed FSDP re-applying activation checkpointing when the user had manually applied it already (#18006)
  • Fixed FSDP re-wrapping the module root when the user had manually wrapped the model (#18054)
  • Fixed issue where unexpected exceptions would leave the default torch dtype modified when using true precision settings (#18500)
  • Fixed redundant input-type casting in FSDP precision (#18630)
  • Fixed an issue with find_usable_cuda_devices(0) incorrectly returning a list of devices (#18722)
  • Fixed redundant file writes in CSVLogger (#18567)

Lightning App

  • Allow customizing gradio components with lightning colors (#17054)
  • Changed LocalSourceCodeDir cache_location to not use home in certain cases (#17491)
  • Removed cluster commands from the CLI (#18151)

Full commit list: 2.0.0...2.1.0


Contributors

@adamjstewart @akreuzer @ethanwharris @dmitsf @lantiga @nicolai86 @pl-ghost @carmocca @awaelchli @justusschock @edenlightning @belerico @lightningforever @nisheethlahoti @tchaton @yurijmikhalevich @mauvilsa @rlizzo @rusmux @yhl48 @Liyang90 @jerome-habana @JustinGoheen @Borda @speediedan @SkafteNicki @dcfidalgo


@saryazdi @parambharat @kshitij12345 @woqidaideshi @colehawkins @md-121 @gkroiz @idc9 @BoringDonut @OmerShubi @ishandutta0098 @ryan597 @leng-yue @alicanb @One-sixth @santurini @SpirinEgor @KogaiIrina @shanmugamr1992 @janeyx99 @asmith26 @dingusagar @AleksanderWWW @strawberrypie @solyaH @kaczmarj @voidful @water-vapor @bkiat1123 @rhiga2 @baskrahmer @felipewhitaker @mukhery @Quasar-Kim @robieta @one-matrix @jere357 @schmidt-ai @schuhschuh @anio @rjarun8 @callumhay @minhlong94 @klieret @giorgioskij @shihaoyin @JonathanRayner @NripeshN @marcimarc1 @bilelomrani1 @NikolasWolke @0x404 @quintenroets @Borodin @amorehead @SebastianGer @ioangatop @Tribhuvan0 @f0k @sameertantry @kwsp @nik777 @matsumotosan

When Chuck Norris trains a neural network, it not only learns, but it also gains the ability to defend itself from adversarial attacks by roundhouse kicking them into submission.

