PyTorch Tensor Corruption Bug: Failed Resizes Explained
Have you ever encountered a mysterious crash or unexpected behavior in your PyTorch code, especially when dealing with tensors that seem to have their dimensions all mixed up? You're not alone! There's a subtle but significant bug within PyTorch, specifically related to how tensor metadata is updated during storage resize operations. This issue can leave your tensors in a corrupted state, often referred to as a "Zombie" tensor, leading to segmentation faults and other runtime errors. Let's dive deep into what causes this problem, why it's so tricky, and what it means for your machine learning workflows.
Understanding the "Zombie" Tensor Phenomenon
The core of the problem lies in the interaction between PyTorch tensors, their underlying storage, and the resize_() operation. When you call resize_() on a tensor, PyTorch changes the tensor's shape and, if the new shape needs more bytes than the current storage holds, grows the storage as well. That second step is only possible if the storage is resizable. Tensors that PyTorch allocates itself own resizable storage, but problems arise when a tensor's storage is not resizable. This often happens when a tensor is created from, or shares storage with, external data structures such as NumPy arrays whose buffers have been injected into PyTorch using methods like set_().
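To make that distinction concrete, here is a small illustrative sketch (not part of the bug report itself): a tensor PyTorch allocated can grow freely, while a tensor that merely views a NumPy buffer cannot grow past that buffer.

import torch
import numpy as np
# A tensor that owns its storage can be grown by resize_().
a = torch.zeros(3)
a.resize_((10,))  # fine: the storage is reallocated behind the scenes
# A tensor viewing a NumPy buffer does not own its storage,
# so growing it beyond the buffer raises a RuntimeError.
b = torch.from_numpy(np.zeros(3, dtype=np.float32))
try:
    b.resize_((10,))  # needs more bytes than the NumPy buffer provides
except RuntimeError as e:
    print(e)  # expected: "Trying to resize storage that is not resizable"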
When resize_() is called on such a tensor, PyTorch correctly identifies that the storage cannot be resized and raises a RuntimeError with a clear message: "Trying to resize storage that is not resizable." Raising the error is the right call. The bug is that the operation isn't exception-safe: PyTorch updates the tensor's shape and stride metadata to the new target size specified in the resize_() call before it verifies that the storage can actually be grown. Only after that metadata update does the resizability check fail and the exception propagate.
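To see why that ordering matters, here is a toy model in plain Python (deliberately not PyTorch's actual C++ internals): the metadata write happens before the storage check, so even a caught exception leaves the object changed.

# Toy model of the problematic ordering -- not PyTorch internals.
class ToyTensor:
    def __init__(self):
        self.shape = (0,)        # metadata
        self.storage_bytes = 0   # actual allocated bytes
        self.resizable = False   # storage borrowed from elsewhere

    def resize_(self, new_shape, element_size=4):
        self.shape = new_shape   # 1. metadata updated first
        needed = element_size
        for d in new_shape:
            needed *= d
        if needed > self.storage_bytes and not self.resizable:
            raise RuntimeError("storage is not resizable")  # 2. check fails afterwards

toy = ToyTensor()
try:
    toy.resize_((5, 5, 5))
except RuntimeError:
    pass
print(toy.shape, toy.storage_bytes)  # (5, 5, 5) 0 -- inconsistent, just like the real bug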
This sequence of events leaves the tensor in a deeply inconsistent state. The tensor's shape attribute might now report a seemingly valid, larger size (e.g., torch.Size([5, 5, 5])), but its actual underlying storage() remains unchanged and, crucially, empty (0 bytes). This mismatch between the advertised shape and the actual available data is what creates the "Zombie" tensor. It looks like a functional tensor with specific dimensions, but it has no data backing it up. Accessing or printing such a tensor, which expects data based on its shape, can lead to catastrophic failures like segmentation faults or internal PyTorch RuntimeErrors because the program tries to read data that doesn't exist or is in an inaccessible location. This bug is particularly insidious because the error doesn't manifest immediately at the resize_() call itself; instead, it surfaces later when you try to use the corrupted tensor, making debugging a real challenge. It’s like having a blueprint for a huge building but only enough materials for a small shed – any attempt to build the large structure will inevitably fail, and in this case, it can bring down the whole program.
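Because the failure only surfaces when the corrupted tensor is used, a defensive check can help on affected versions. The helper below is a hypothetical sketch (the name is_consistent is ours, and it assumes a contiguous layout): it compares the bytes the shape implies against the bytes the storage actually holds.

import torch
# Hypothetical helper (assumes a contiguous layout): does the storage back the shape?
def is_consistent(t: torch.Tensor) -> bool:
    needed = (t.storage_offset() + t.numel()) * t.element_size()
    return t.untyped_storage().nbytes() >= needed

A "Zombie" tensor like the one produced in the reproduction below advertises a 5x5x5 shape over 0 bytes of storage, so this check would return False.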
A Minimal Reproduction: Witnessing the Corruption
To truly understand the severity and nature of this bug, it's essential to see it in action. The PyTorch development team has provided a minimal reproduction case that clearly illustrates the issue. Let's break down the code and what it demonstrates:
First, we need to set up a scenario where a tensor has non-resizable storage. This is achieved by creating a NumPy array with zero elements and then converting its underlying storage into a PyTorch untyped_storage. This locked_storage is essentially an empty buffer that PyTorch cannot resize.
import torch
import numpy as np
# Create non-resizable storage (0 bytes)
locked_storage = torch.from_numpy(np.array([], dtype=np.int32)).untyped_storage()
# Inject into a fresh tensor
t = torch.tensor([], dtype=torch.int32)
t.set_(locked_storage)
Here, locked_storage is created from an empty NumPy array, so its nbytes() is 0. The t.set_(locked_storage) call attaches this empty, non-resizable storage to our new PyTorch tensor t. At this point, t is a valid, albeit empty, tensor: its shape is torch.Size([0]) and its untyped_storage().nbytes() is 0.
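As a quick sanity check (a small sketch using the variables from the snippet above), you can confirm that the storage itself refuses to grow even before the tensor-level resize_() is attempted:

# The borrowed storage itself cannot be grown.
try:
    locked_storage.resize_(100)
except RuntimeError as e:
    print(e)  # expected: "Trying to resize storage that is not resizable"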
Now comes the critical part: attempting to resize this tensor. We use a try-except block because we expect a RuntimeError to be raised when we try to resize a tensor with locked storage. However, the bug reveals itself in what happens inside that except block:
# Attempt to resize (Expected: Fail, maintain original shape)
# (Actual: Fails, but updates shape to 5x5x5)
try:
    t.resize_((5, 5, 5))
except RuntimeError:
    pass  # The error is caught here
The expected behavior is that, if resize_() fails because the storage is not resizable, the tensor's metadata (its shape and strides) remains unchanged and still describes the original empty tensor. The actual behavior is that the RuntimeError is raised only after the shape metadata has already been modified to torch.Size([5, 5, 5]), while the storage size remains at 0 bytes.
Finally, we can verify the corruption. The code prints the shape and the storage size. The output is revealing:
# Verify corruption
print(f"Shape: {t.shape}") # Prints: torch.Size([5, 5, 5])
print(f"Storage: {t.untyped_storage().nbytes()}") # Prints: 0
print(t) # CRASH
The output clearly shows the discrepancy: the shape is now torch.Size([5, 5, 5]), indicating a tensor that should hold 125 int32 elements (500 bytes). Yet t.untyped_storage().nbytes() reports 0 bytes, meaning no data is allocated for those elements. The final print(t) line is where the program typically crashes, either with a segmentation fault or another internal error, because it attempts to read and display data that simply doesn't exist in the 0-byte storage. This minimal example effectively isolates the bug, demonstrating how a failed resize_() operation can corrupt a tensor's metadata and destabilize the program.
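If you do end up with such a tensor on an affected version, one possible way to bring it back to a consistent state (a sketch, not an official remediation) is to re-attach the storage with an explicit empty size via set_(), so the metadata agrees with the 0-byte buffer again:

# Re-attach the same storage with an explicit size to repair the metadata.
t.set_(locked_storage, 0, (0,))  # source storage, storage_offset, size
print(t.shape)                   # torch.Size([0])
print(t)                         # prints an empty tensor instead of crashing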
The Importance of Exception Safety and Strong Guarantees
In software engineering, exception safety is a crucial concept, especially in libraries that deal with low-level operations and complex data structures like PyTorch. When an operation fails and throws an exception, the system should ideally be left in a known, consistent state. There are different levels of exception safety guarantees:
- Basic Guarantee: If an exception is thrown, the program remains in a valid state, meaning no memory leaks or corruption occur. However, the affected objects may be left in a changed, though still valid, state.
- Strong Guarantee: If an exception is thrown, the program state is rolled back to exactly what it was before the operation. No changes are made, and no resources are lost. This is often referred to as