Serialization Py
Serialization Py
import os
import io
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from enum import Enum
from ._utils import _import_dotted_name
from torch._sources import get_source_lines_and_file
from torch.types import Storage
from torch.storage import _get_dtype_from_pickle_storage_type
from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Type, Tuple,
Union, IO, List
from typing_extensions import TypeAlias, TypeGuard # Python 3.10+
import copyreg
import pickle
import torch._weights_only_unpickler as _weights_only_unpickler
# Pickle protocol that ``save`` uses when the caller does not pass one.
DEFAULT_PROTOCOL = 2

# Byte sizes of the standard-size ('=') C integer types; the legacy writer
# records them in the ``type_sizes`` entry of its system-info header.
SHORT_SIZE = struct.calcsize('=h')
INT_SIZE = struct.calcsize('=i')
LONG_SIZE = struct.calcsize('=l')

# Values pickled at the start of a legacy-format file.
MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
PROTOCOL_VERSION = 1001

# NOTE(review): presumably joins the parts of a composite storage key; its use
# is not visible in this part of the file -- confirm against the callers.
STORAGE_KEY_SEPARATOR = ','
# Names exported by ``from <this module> import *``; every entry is expected to
# be defined at module level in this file.
__all__ = [
'SourceChangeWarning',
'mkdtemp',
'register_package',
'check_module_version_greater_or_equal',
'validate_cuda_device',
'validate_hpu_device',
'location_tag',
'default_restore_location',
'normalize_storage_type',
'storage_to_tensor_type',
'save',
'load',
'StorageType',
'LoadEndianness',
'get_default_load_endianness',
'set_default_load_endianness',
]
class SourceChangeWarning(Warning):
    """Warning category defined by the serialization module.

    NOTE(review): presumably emitted when the source of a loaded container no
    longer matches what was saved -- the emitting code is not visible here.
    """
@contextmanager
def mkdtemp():
    """Yield the path of a fresh temporary directory and delete it on exit.

    The directory tree is removed when the ``with`` block ends, whether it
    finishes normally or raises.
    """
    tmp_dir = tempfile.mkdtemp()
    try:
        yield tmp_dir
    finally:
        shutil.rmtree(tmp_dir)
class LoadEndianness(Enum):
    """Byte orders that can be chosen as the fallback when loading files."""

    NATIVE = 1  # byte order of the running machine
    LITTLE = 2  # little-endian
    BIG = 3     # big-endian
Returns:
default_load_endian: Optional[LoadEndianness]
'''
return _default_load_endian
def set_default_load_endianness(endianness):
    '''
    Set fallback byte order for loading files

    Args:
        endianness: the new fallback byte order; a :class:`LoadEndianness`
            member, or ``None`` to clear the setting

    Raises:
        TypeError: if ``endianness`` is neither a ``LoadEndianness`` nor ``None``
    '''
    global _default_load_endian
    accepted = endianness is None or isinstance(endianness, LoadEndianness)
    if not accepted:
        raise TypeError("Invalid argument type in function set_default_load_endianness")
    _default_load_endian = endianness
start = f.tell()
# Read the first few bytes and match against the ZIP file signature
local_header_magic_number = b'PK\x03\x04'
read_bytes = f.read(len(local_header_magic_number))
f.seek(start)
return read_bytes == local_header_magic_number
def register_package(
    priority: int,
    tagger: Callable[[STORAGE], Optional[str]],
    deserializer: Callable[[STORAGE, str], Optional[STORAGE]]
):
    '''
    Registers callables for tagging and deserializing storage objects with an associated priority.
    Tagging associates a device with a storage object at save time while deserializing moves a
    storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
    are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
    value that is not `None`.

    This function can also be used to register a tagger and deserializer for new devices.

    Args:
        priority: Indicates the priority associated with the tagger and deserializer, where a lower
            value indicates higher priority. Entries registered with the same priority keep
            their registration order.
        tagger: Callable that takes in a storage object and returns its tagged device as a string
            or None.
        deserializer: Callable that takes in storage object and a device string and returns a storage
            object on the appropriate device or None.

    Returns:
        `None`

    Example:
        >>> def ipu_tag(obj):
        >>>     if obj.device.type == 'ipu':
        >>>         return 'ipu'
        >>> def ipu_deserialize(obj, location):
        >>>     if location.startswith('ipu'):
        >>>         ipu = getattr(torch, "ipu", None)
        >>>         assert ipu is not None, "IPU device module is not loaded"
        >>>         assert torch.ipu.is_available(), "ipu is not available"
        >>>         return obj.ipu(location)
        >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
    '''
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    # Sort on the priority alone. Sorting the raw tuples would compare the
    # callables whenever two priorities tie, and functions are not orderable
    # (TypeError). The sort is stable, so ties keep their registration order.
    _package_registry.sort(key=lambda entry: entry[0])
Args:
module: the module to check the version of
req_version_tuple: tuple (usually of ints) representing the required
version
error_if_malformed: whether we should exit if module version string
is malformed
Returns:
requirement_is_met: bool
'''
try:
version_strs = module.__version__.split('.')
# Cast module version fields to match the types of the required
version
module_version = tuple(
type(req_field)(version_strs[idx]) for idx, req_field in
enumerate(req_version_tuple)
)
requirement_is_met = module_version >= req_version_tuple
except Exception as e:
message = (
f"'{module.__name__}' module version string is malformed
'{module.__version__}' and cannot be compared"
f" with tuple {str(req_version_tuple)}"
)
if error_if_malformed:
raise RuntimeError(message) from e
else:
warnings.warn(message + ', but continuing assuming that
requirement is met')
requirement_is_met = True
return requirement_is_met
def _cpu_tag(obj):
    """Return ``'cpu'`` when ``obj`` lives on a CPU device, otherwise ``None``."""
    return 'cpu' if obj.device.type == 'cpu' else None
def _cuda_tag(obj):
    """Return ``'cuda:<index>'`` when ``obj`` lives on a CUDA device, otherwise ``None``."""
    device = obj.device
    if device.type != 'cuda':
        return None
    return f'cuda:{device.index!s}'
def _hpu_tag(obj):
    """Return ``'hpu:<index>'`` when ``obj`` lives on an HPU device, otherwise ``None``."""
    device = obj.device
    if device.type != 'hpu':
        return None
    return f'hpu:{device.index!s}'
def _mps_tag(obj):
    """Return ``'mps'`` when ``obj`` lives on an MPS device, otherwise ``None``."""
    return 'mps' if obj.device.type == 'mps' else None
def _meta_tag(obj):
    """Return ``'meta'`` when ``obj`` lives on the meta device, otherwise ``None``."""
    return 'meta' if obj.device.type == 'meta' else None
def _privateuse1_tag(obj):
backend_name = torch._C._get_privateuse1_backend_name()
if obj.device.type == backend_name:
if obj.device.index is None:
return backend_name
else:
return backend_name + ':' + str(obj.device.index)
def validate_cuda_device(location):
    """Resolve ``location`` to a CUDA device index and check that it is usable.

    Args:
        location: device designation handed to
            ``torch.cuda._utils._get_device_index``

    Returns:
        The resolved device index.

    Raises:
        RuntimeError: if CUDA is unavailable, or the index is not below
            ``torch.cuda.device_count()``.
    """
    device = torch.cuda._utils._get_device_index(location, True)

    if not torch.cuda.is_available():
        unavailable_msg = (
            "Attempting to deserialize object on a CUDA "
            "device but torch.cuda.is_available() is False. "
            "If you are running on a CPU-only machine, "
            "please use torch.load with map_location=torch.device('cpu') "
            "to map your storages to the CPU."
        )
        raise RuntimeError(unavailable_msg)

    visible_devices = torch.cuda.device_count()
    if device >= visible_devices:
        out_of_range_msg = (
            "Attempting to deserialize object on CUDA device "
            f"{device} but torch.cuda.device_count() is {visible_devices}. Please use "
            "torch.load with map_location to map your storages "
            "to an existing device."
        )
        raise RuntimeError(out_of_range_msg)

    return device
def validate_hpu_device(location):
    """Resolve ``location`` to an HPU device index and check that it is usable.

    Args:
        location: device designation handed to ``torch.hpu._utils._get_device_index``

    Returns:
        The resolved device index.

    Raises:
        AssertionError: if ``torch`` has no ``hpu`` module attribute.
        RuntimeError: if the HPU backend is unavailable, or the index is not
            below ``torch.hpu.device_count()``.
    """
    hpu = getattr(torch, "hpu", None)
    assert hpu is not None, "HPU device module is not loaded"
    device = hpu._utils._get_device_index(location, optional=True)

    if not hpu.is_available():
        unavailable_msg = (
            "Attempting to deserialize object on a HPU "
            "device but torch.hpu.is_available() is False. "
            "If you are running on a CPU-only machine, "
            "please use torch.load with map_location=torch.device('cpu') "
            "to map your storages to the CPU."
        )
        raise RuntimeError(unavailable_msg)

    visible_devices = hpu.device_count()
    if device >= visible_devices:
        out_of_range_msg = (
            "Attempting to deserialize object on HPU device "
            f"{device} but torch.hpu.device_count() is {visible_devices}. Please use "
            "torch.load with map_location to map your storages "
            "to an existing device."
        )
        raise RuntimeError(out_of_range_msg)

    return device
Args:
location: string of device
backend_name: the name of privateuse1, which can be renamed
Returns:
device_index: int
'''
if not hasattr(torch, backend_name):
raise RuntimeError(f'The {backend_name.upper()} device module is not
registered. '
'If you are running on a CPU-only machine, '
'please use torch.load with
map_location=torch.device(\'cpu\') '
'to map your storages to the CPU.')
device_module = getattr(torch, backend_name)
if hasattr(device_module, '_utils') and hasattr(device_module._utils,
'_get_device_index'):
device_index = device_module._utils._get_device_index(location, True)
else:
device = torch.device(location)
device_index = device.index if device.index else 0
if hasattr(device_module, 'is_available') and not
device_module.is_available():
raise RuntimeError(f'Attempting to deserialize object on a
{backend_name.upper()} '
f'device but torch.{backend_name}.is_available() is
False. '
'If you are running on a CPU-only machine, '
'please use torch.load with
map_location=torch.device(\'cpu\') '
'to map your storages to the CPU.')
if hasattr(device_module, 'device_count'):
device_count = device_module.device_count()
if device_index >= device_count:
raise RuntimeError(f'Attempting to deserialize object on
{backend_name.upper()} device '
f'{device_index} but torch.
{backend_name}.device_count() is {device_count}. '
'Please use torch.load with map_location to map
your storages '
'to an existing device.')
return device_index
def _privateuse1_deserialize(obj, location):
backend_name = torch._C._get_privateuse1_backend_name()
if location.startswith(backend_name):
if not hasattr(obj, backend_name):
raise RuntimeError(f'Attempting to load the storages to the
{backend_name.upper()} device '
f'but torch.storage._StorageBase.{backend_name}
() or '
f'torch.storage.TypedStorage.{backend_name}()
is not generated. '
'Please use
torch.utils.generate_methods_for_privateuse1_backend '
f'to generate storage.{backend_name}() method
first.')
device_index = _validate_privateuse1_device(location, backend_name)
return getattr(obj, backend_name)(device_index)
def normalize_storage_type(storage_type):
    """Return the attribute of ``torch`` that has the same name as ``storage_type``."""
    type_name = storage_type.__name__
    return getattr(torch, type_name)
def storage_to_tensor_type(storage):
    """Look up the tensor class that corresponds to the class of ``storage``.

    The class name has 'Storage' replaced by 'Tensor' and is then read from
    the module in which the storage class is defined.
    """
    storage_cls = type(storage)
    tensor_name = storage_cls.__name__.replace('Storage', 'Tensor')
    defining_module = _import_dotted_name(storage_cls.__module__)
    return getattr(defining_module, tensor_name)
def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
return isinstance(name_or_buffer, (str, os.PathLike))
class _opener:
    """Minimal context manager that hands out a wrapped file-like object.

    ``with _opener(obj) as f`` binds ``f`` to ``obj``. Leaving the block does
    nothing here; subclasses override ``__exit__`` to flush or close.
    """

    def __init__(self, file_like):
        self.file_like = file_like

    def __enter__(self):
        return self.file_like

    def __exit__(self, *args):
        # Without this method the class does not satisfy the context-manager
        # protocol and every ``with`` statement over it fails.
        pass
class _open_file(_opener):
    """Context manager that opens ``name`` with ``mode`` and closes it on exit."""

    def __init__(self, name, mode):
        super().__init__(open(name, mode))

    def __exit__(self, *args):
        # This class opened the file itself, so it must close it; otherwise
        # the descriptor leaks until garbage collection.
        self.file_like.close()
class _open_buffer_reader(_opener):
    """Wrap a caller-owned buffer for reading, after checking it is seekable."""

    def __init__(self, buffer):
        super().__init__(buffer)
        # The check runs after the base initialiser, so ``self.file_like`` is
        # already set if it raises.
        _check_seekable(buffer)
class _open_buffer_writer(_opener):
    """Wrap a caller-owned buffer for writing; flushes it when the block ends.

    The buffer is flushed but not closed.
    """

    def __exit__(self, *args):
        self.file_like.flush()
class _open_zipfile_reader(_opener):
    """Context manager yielding a ``torch._C.PyTorchFileReader`` over a path or buffer."""

    def __init__(self, name_or_buffer) -> None:
        reader = torch._C.PyTorchFileReader(name_or_buffer)
        super().__init__(reader)
class _open_zipfile_writer_file(_opener):
    """Context manager yielding a ``torch._C.PyTorchFileWriter`` for a file path.

    ASCII paths are passed straight to the writer. For any other path the
    file is opened from Python and the writer is given that stream.
    """

    def __init__(self, name) -> None:
        self.file_stream = None
        self.name = str(name)
        try:
            self.name.encode('ascii')
        except UnicodeEncodeError:
            # PyTorchFileWriter only supports ascii filename.
            # For filenames with non-ascii characters, we rely on Python
            # for writing out the file.
            self.file_stream = io.FileIO(self.name, mode='w')
            super().__init__(torch._C.PyTorchFileWriter(self.file_stream))
        else:
            super().__init__(torch._C.PyTorchFileWriter(self.name))

    def __exit__(self, *args) -> None:
        # Finalize the archive first: the writer may still emit bytes, so the
        # Python-side stream must stay open until it has finished.
        self.file_like.write_end_of_file()
        if self.file_stream is not None:
            # Opened in __init__, so it is closed here to avoid a leak.
            self.file_stream.close()
class _open_zipfile_writer_buffer(_opener):
    """Context manager yielding a ``torch._C.PyTorchFileWriter`` over a writable buffer.

    Raises:
        AttributeError: if ``buffer`` has no ``write`` attribute.
        TypeError: if ``buffer.write`` exists but is not callable.
    """

    def __init__(self, buffer) -> None:
        if not callable(getattr(buffer, "write", None)):
            msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
            if not hasattr(buffer, "write"):
                raise AttributeError(msg)
            raise TypeError(msg)
        self.buffer = buffer
        super().__init__(torch._C.PyTorchFileWriter(buffer))

    def __exit__(self, *args) -> None:
        # Finalize the archive, then flush what the writer produced into the
        # caller's buffer. The caller owns the buffer, so it is not closed.
        self.file_like.write_end_of_file()
        self.buffer.flush()
def _open_zipfile_writer(name_or_buffer):
    """Return the zipfile-writer context manager that fits the argument.

    Paths get ``_open_zipfile_writer_file``; anything else is treated as a
    buffer and gets ``_open_zipfile_writer_buffer``.
    """
    container: Type[_opener] = (
        _open_zipfile_writer_file
        if _is_path(name_or_buffer)
        else _open_zipfile_writer_buffer
    )
    return container(name_or_buffer)
def _should_read_directly(f):
    """
    Checks if f is a file that should be read directly. It should be read
    directly if it is backed by a real file (has a fileno) and is not
    a compressed file (e.g. gzip)
    """
    if _is_compressed_file(f):
        return False
    try:
        backed_by_descriptor = f.fileno() >= 0
    except (io.UnsupportedOperation, AttributeError):
        # No usable fileno(): either the object lacks the method or the
        # stream reports that it has no underlying descriptor.
        backed_by_descriptor = False
    return backed_by_descriptor
try:
f.seek(f.tell())
return True
except (io.UnsupportedOperation, AttributeError) as e:
raise_err_msg(["seek", "tell"], e)
return False
Args:
pickle_module: module used for pickling metadata and objects
'''
if pickle_module is not None and pickle_module.__name__ == 'dill':
required_dill_version = (0, 3, 1)
if not check_module_version_greater_or_equal(pickle_module,
required_dill_version, False):
raise ValueError((
"'torch' supports dill >= {}, but you have dill {}."
" Please upgrade dill or switch to 'pickle'"
).format(
'.'.join([str(num) for num in required_dill_version]),
pickle_module.__version__
))
def _check_save_filelike(f):
    """Raise ``AttributeError`` unless ``f`` is a path or has a ``write`` attribute."""
    if _is_path(f) or hasattr(f, 'write'):
        return
    raise AttributeError(
        "expected 'f' to be string, path, or a file-like object with "
        "a 'write' attribute")
def save(
    obj: object,
    f: FILE_LIKE,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
    _use_new_zipfile_serialization: bool = True,
    _disable_byteorder_record: bool = False
) -> None:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path').

    """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

    Saves an object to a disk file.

    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string or
           os.PathLike object containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol

    .. note::
        A common PyTorch convention is to save tensors using .pt file extension.

    .. note::
        PyTorch preserves storage sharing across serialization. See
        :ref:`preserve-storage-sharing` for more details.

    .. note::
        The 1.6 release of PyTorch switched ``torch.save`` to use a new
        zipfile-based file format. ``torch.load`` still retains the ability to
        load files in the old format. If for any reason you want ``torch.save``
        to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.

    Example:
        >>> # xdoctest: +SKIP("makes cwd dirty")
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
    """
    torch._C._log_api_usage_once("torch.save")
    # Validate the arguments before anything is opened or written.
    _check_dill_version(pickle_module)
    _check_save_filelike(f)

    if _use_new_zipfile_serialization:
        with _open_zipfile_writer(f) as opened_zipfile:
            # ``_disable_byteorder_record`` is a private flag that only the
            # zipfile writer receives; the legacy path below ignores it.
            _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
            return
    else:
        with _open_file_like(f, 'wb') as opened_file:
            _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
# Since loading storages that view the same data with different dtypes is
# not supported, we need to keep track of the dtype associated with each
# storage data_ptr and throw an error if the dtype is ever different.
# TODO: This feature could be added in the future
storage_dtypes: Dict[int, torch.dtype] = {}
if isinstance(obj, torch.storage.TypedStorage) or
torch.is_storage(obj):
storage: torch.UntypedStorage
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, this case
# can be deleted
storage = obj._untyped_storage
storage_dtype = obj.dtype
storage_type_str = obj._pickle_storage_type()
storage_type = getattr(torch, storage_type_str)
dtype = obj.dtype
storage_numel = obj._size()
res = ('storage',
storage_type,
storage_key,
location,
storage_numel,
view_metadata)
return res
return None
sys_info = dict(
protocol_version=PROTOCOL_VERSION,
little_endian=sys.byteorder == 'little',
type_sizes=dict(
short=SHORT_SIZE,
int=INT_SIZE,
long=LONG_SIZE,
),
)
pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
pickle_module.dump(sys_info, f, protocol=pickle_protocol)
pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
pickler.persistent_id = persistent_id
pickler.dump(obj)
serialized_storage_keys = sorted(serialized_storages.keys())
pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
f.flush()
for key in serialized_storage_keys:
storage, dtype = serialized_storages[key]
storage._write_file(f, _should_read_directly(f), True,
torch._utils._element_size(dtype))
def _save(obj, zip_file, pickle_module, pickle_protocol,
_disable_byteorder_record):
serialized_storages = {}
id_map: Dict[int, str] = {}
# Since loading storages that view the same data with different dtypes is
# not supported, we need to keep track of the dtype associated with each
# storage data_ptr and throw an error if the dtype is ever different.
# TODO: This feature could be added in the future
storage_dtypes: Dict[int, torch.dtype] = {}
def persistent_id(obj):
# FIXME: the docs say that persistent_id should only return a string
# but torch store returns tuples. This works only in the binary
protocol
# see
# https://docs.python.org/2/library/pickle.html#pickling-and-
unpickling-external-objects
# https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-
L537
if isinstance(obj, torch.storage.TypedStorage) or
torch.is_storage(obj):
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, this case
# can be deleted
storage = obj._untyped_storage
storage_dtype = obj.dtype
storage_type_str = obj._pickle_storage_type()
storage_type = getattr(torch, storage_type_str)
storage_numel = obj._size()
else:
storage = obj
storage_dtype = torch.uint8
storage_type = normalize_storage_type(type(obj))
storage_numel = storage.nbytes()
return ('storage',
storage_type,
storage_key,
location,
storage_numel)
return None
def load(
f: FILE_LIKE,
map_location: MAP_LOCATION = None,
pickle_module: Any = None,
*,
weights_only: bool = False,
mmap: Optional[bool] = None,
**pickle_load_args: Any
) -> Any:
# Reference: https://github.com/pytorch/pytorch/issues/54354
# The first line of this docstring overrides the one Sphinx generates for
the
# documentation. We need it so that Sphinx doesn't leak `pickle`s path
from
# the build environment (e.g. `<module 'pickle' from '/leaked/path').
User extensions can register their own location tags and tagging and
deserialization methods
using :func:`torch.serialization.register_package`.
Args:
f: a file-like object (has to
implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
or a string or os.PathLike object containing a file name
map_location: a function, :class:`torch.device`, string or a dict
specifying how to remap storage
locations
pickle_module: module used for unpickling metadata and objects (has
to
match the :attr:`pickle_module` used to serialize file)
weights_only: Indicates whether unpickler should be restricted to
loading only tensors, primitive types and dictionaries
mmap: Indicates whether the file should be mmaped rather than loading
all the storages into memory.
Typically, tensor storages in the file will first be moved from
disk to CPU memory, after which they
are moved to the location that they were tagged with when saving,
or specified by ``map_location``. This
second step is a no-op if the final location is CPU. When the
``mmap`` flag is set, instead of copying the
tensor storages from disk to CPU memory in the first step, ``f``
is mmaped.
pickle_load_args: (Python 3 only) optional keyword arguments passed
over to
:func:`pickle_module.load` and :func:`pickle_module.Unpickler`,
e.g.,
:attr:`errors=...`.
.. warning::
:func:`torch.load()` unless `weights_only` parameter is set to
`True`,
uses ``pickle`` module implicitly, which is known to be insecure.
It is possible to construct malicious pickle data which will execute
arbitrary code
during unpickling. Never load data that could have come from an
untrusted
source in an unsafe mode, or that could have been tampered with.
**Only load data you trust**.
.. note::
When you call :func:`torch.load()` on a file which contains GPU
tensors, those tensors
will be loaded to GPU by default. You can call ``torch.load(..,
map_location='cpu')``
and then :meth:`load_state_dict` to avoid GPU RAM surge when loading
a model checkpoint.
.. note::
By default, we decode byte strings as ``utf-8``. This is to avoid a
common error
case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
when loading files saved by Python 2 in Python 3. If this default
is incorrect, you may use an extra :attr:`encoding` keyword argument
to specify how
these objects should be loaded, e.g., :attr:`encoding='latin1'`
decodes them
to strings using ``latin1`` encoding, and :attr:`encoding='bytes'`
keeps them
as byte arrays which can be decoded later with
``byte_array.decode(...)``.
Example:
>>> # xdoctest: +SKIP("undefined filepaths")
>>> torch.load('tensors.pt', weights_only=True)
# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location=torch.device('cpu'),
weights_only=True)
# Load all tensors onto the CPU, using a function
>>> torch.load('tensors.pt', map_location=lambda storage, loc:
storage, weights_only=True)
# Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc:
storage.cuda(1), weights_only=True)
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'},
weights_only=True)
# Load tensor from io.BytesIO object
# Loading from a buffer setting weights_only=False, warning this can
be unsafe
>>> with open('tensor.pt', 'rb') as f:
... buffer = io.BytesIO(f.read())
>>> torch.load(buffer, weights_only=False)
# Load a module with 'ascii' encoding for unpickling
# Loading from a module setting weights_only=False, warning this can
be unsafe
>>> torch.load('module.pt', encoding='ascii', weights_only=False)
"""
torch._C._log_api_usage_once("torch.load")
UNSAFE_MESSAGE = (
"Weights only load failed. Re-running `torch.load` with `weights_only`
set to `False`"
" will likely succeed, but it can result in arbitrary code execution."
"Do it only if you get the file from a trusted source.
WeightsUnpickler error: "
)
# Add ability to force safe only weight loads via environment variable
if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y',
'yes', 'true']:
weights_only = True
if weights_only:
if pickle_module is not None:
raise RuntimeError("Can not safely load weights when explicit
pickle_module is specified")
else:
if pickle_module is None:
pickle_module = pickle
_check_dill_version(pickle_module)
f"`torch.save({f_name}_use_new_zipfile_serialization=True), "
"please torch.save your checkpoint with this
option in order to use mmap.")
if weights_only:
try:
return _legacy_load(opened_file, map_location,
_weights_only_unpickler, **pickle_load_args)
except RuntimeError as e:
raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from
None
return _legacy_load(opened_file, map_location, pickle_module,
**pickle_load_args)
# There is not yet a good way to type annotate function attributes; see
# https://github.com/python/mypy/issues/2087
_get_layout.cache = {} # type: ignore[attr-defined]
copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
restore_location = _get_restore_location(map_location)
def legacy_load(f):
deserialized_objects: Dict[int, Any] = {}
def persistent_load(saved_id):
if isinstance(saved_id, tuple):
# Ignore containers that don't have any sources saved
if all(saved_id[1:]):
_check_container_source(*saved_id)
return saved_id[0]
return deserialized_objects[int(saved_id)]
tar.extract('storages', path=tmpdir)
with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
num_storages = pickle_module.load(f, **pickle_load_args)
for i in range(num_storages):
args = pickle_module.load(f, **pickle_load_args)
key, location, storage_type = args
dtype = storage_type._dtype
obj = cast(Storage,
torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
obj = restore_location(obj, location)
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
deserialized_objects[key] = torch.storage.TypedStorage(
wrap_storage=obj,
dtype=dtype,
_internal=True)
wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel *
element_size],
dtype=root.dtype,
_internal=True)
tar.extract('tensors', path=tmpdir)
with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
num_tensors = pickle_module.load(f, **pickle_load_args)
for _ in range(num_tensors):
args = pickle_module.load(f, **pickle_load_args)
key, storage_id, original_tensor_type = args
storage = deserialized_objects[storage_id]
ndim, = struct.unpack('<i', f.read(4))
# skip next 4 bytes; legacy encoding treated ndim as 8
bytes
f.read(4)
numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
storage_offset, = struct.unpack('<q', f.read(8))
tensor = torch.empty((0,), dtype=storage.dtype).set_(
storage._untyped_storage, storage_offset, numel,
stride)
deserialized_objects[key] = tensor
pickle_file = tar.extractfile('pickle')
unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
unpickler.persistent_load = persistent_load
result = unpickler.load()
return result
deserialized_objects = {}
def persistent_load(saved_id):
assert isinstance(saved_id, tuple)
typename = _maybe_decode_ascii(saved_id[0])
data = saved_id[1:]
if typename == 'module':
# Ignore containers that don't have any sources saved
if all(data[1:]):
_check_container_source(*data)
return data[0]
elif typename == 'storage':
storage_type, root_key, location, numel, view_metadata = data
location = _maybe_decode_ascii(location)
dtype = storage_type.dtype
wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes +
view_size_bytes],
dtype=dtype,
_internal=True)
res = deserialized_objects[view_key]
else:
res = typed_storage
return res
else:
raise RuntimeError(f"Unknown saved id type: {saved_id[0]}")
_check_seekable(f)
f_should_read_directly = _should_read_directly(f)
if torch._guards.active_fake_mode() is None:
offset = f.tell() if f_should_read_directly else None
for key in deserialized_storage_keys:
assert key in deserialized_objects
typed_storage = deserialized_objects[key]
typed_storage._untyped_storage._set_from_file(
f, offset, f_should_read_directly,
torch._utils._element_size(typed_storage.dtype))
if offset is not None:
offset = f.tell()
torch._utils._validate_loaded_sparse_tensors()
return result
def _get_restore_location(map_location):
    """Build the ``(storage, location) -> storage`` callback for a ``map_location``.

    * ``None``: ``default_restore_location`` is returned unchanged.
    * ``dict``: the saved location is looked up in the dict (falling back to
      itself) and passed to ``default_restore_location``.
    * ``str`` / ``bytes``: every storage is restored to ``map_location``.
    * ``torch.device``: every storage is restored to ``str(map_location)``.
    * anything else is called as ``map_location(storage, location)``; a
      ``None`` result falls back to ``default_restore_location``.
    """
    if map_location is None:
        return default_restore_location

    if isinstance(map_location, dict):
        def restore_location(storage, location):
            remapped = map_location.get(location, location)
            return default_restore_location(storage, remapped)
        return restore_location

    if isinstance(map_location, (str, bytes)):
        def restore_location(storage, location):
            return default_restore_location(storage, map_location)
        return restore_location

    if isinstance(map_location, torch.device):
        def restore_location(storage, location):
            return default_restore_location(storage, str(map_location))
        return restore_location

    def restore_location(storage, location):
        mapped = map_location(storage, location)
        if mapped is not None:
            return mapped
        return default_restore_location(storage, location)
    return restore_location
class StorageType:
    """Holds the dtype that corresponds to a pickled storage type name.

    The dtype is resolved once, at construction, by
    ``_get_dtype_from_pickle_storage_type`` and exposed read-only as ``dtype``.
    """

    def __init__(self, name):
        self._dtype = _get_dtype_from_pickle_storage_type(name)

    @property
    def dtype(self):
        return self._dtype

    def __str__(self):
        dtype = self.dtype
        return f'StorageType(dtype={dtype})'
loaded_storages = {}
if typed_storage._data_ptr() != 0:
loaded_storages[key] = typed_storage
return typed_storage
def persistent_load(saved_id):
assert isinstance(saved_id, tuple)
typename = _maybe_decode_ascii(saved_id[0])
data = saved_id[1:]
if key in loaded_storages:
typed_storage = loaded_storages[key]
else:
nbytes = numel * torch._utils._element_size(dtype)
typed_storage = load_tensor(dtype, nbytes, key,
_maybe_decode_ascii(location))
return typed_storage
load_module_mapping: Dict[str, str] = {
# See https://github.com/pytorch/pytorch/pull/51633
'torch.tensor': 'torch._tensor'
}
# Load the data (which may in turn use `persistent_load` to load tensors)
data_file = io.BytesIO(zip_file.get_record(pickle_file))
torch._utils._validate_loaded_sparse_tensors()
torch._C._log_api_usage_metadata(
"torch.load.metadata", {"serialization_id":
zip_file.serialization_id()}
)
return result
def _is_torchscript_zip(zip_file):
    """Return True when the archive's records include ``'constants.pkl'``."""
    record_names = zip_file.get_all_records()
    return 'constants.pkl' in record_names