Welcome! For new projects I now strongly recommend using my newer jaxtyping project instead. It supports PyTorch, doesn't actually depend on JAX, and unlike TorchTyping it is compatible with static type checkers. The 'jax' in the name is now historical!
The original torchtyping README is as follows.
Turn this:

```python
def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (batch, x_channels)
    # y has shape (batch, y_channels)
    # return has shape (batch, x_channels, y_channels)
    return x.unsqueeze(-1) * y.unsqueeze(-2)
```

into this:

```python
def batch_outer_product(x:   TensorType["batch", "x_channels"],
                        y:   TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:
    return x.unsqueeze(-1) * y.unsqueeze(-2)
```

with programmatic checking that the shape (dtype, ...) specification is met.
Bye-bye bugs! Say hello to enforced, clear documentation of your code.
If (like me) you find yourself littering your code with comments like `# x has shape (batch, hidden_state)` or statements like `assert x.shape == y.shape`, just to keep track of what shape everything is, then this is for you.
```bash
pip install torchtyping
```

Requires Python >=3.7 and PyTorch >=1.7.0.

If using `typeguard` then it must be a version `<3.0.0`.
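For example, one way to install both packages together while respecting that pin (a suggested command, not an official install line from the project):

```shell
pip install torchtyping "typeguard<3.0.0"
```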
`torchtyping` allows for type annotating:

- **shape**: size, number of dimensions;
- **dtype** (float, integer, etc.);
- **layout** (dense, sparse);
- **names** of dimensions as per named tensors;
- **arbitrary number of batch dimensions** with `...`;
- ...plus anything else you like, as `torchtyping` is highly extensible.
If `typeguard` is (optionally) installed then at runtime the types can be checked to ensure that the tensors really are of the advertised shape, dtype, etc.
Example:

```python
from torch import rand
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # use before @typechecked

@typechecked
def func(x: TensorType["batch"],
         y: TensorType["batch"]) -> TensorType["batch"]:
    return x + y

func(rand(3), rand(3))  # works
func(rand(3), rand(1))
# TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3.
```

`typeguard` also has an import hook that can be used to automatically test an entire module, without needing to manually add `@typeguard.typechecked` decorators.
If you're not using `typeguard` then `torchtyping.patch_typeguard()` can be omitted altogether, and `torchtyping` just used for documentation purposes. If you're not already using `typeguard` for your regular Python programming, then strongly consider using it. It's a great way to squash bugs. Both `typeguard` and `torchtyping` also integrate with pytest, so if you're concerned about any performance penalty then they can be enabled during tests only.
```python
torchtyping.TensorType[shape, dtype, layout, details]
```

The core of the library. Each of `shape`, `dtype`, `layout`, `details` is optional.

- The `shape` argument can be any of:
  - An `int`: the dimension must be of exactly this size. If it is `-1` then any size is allowed.
  - A `str`: the size of the dimension passed at runtime will be bound to this name, and all tensors checked that the sizes are consistent.
  - A `...`: An arbitrary number of dimensions of any sizes.
  - A `str: int` pair (technically it's a slice), combining both `str` and `int` behaviour. (Just a `str` on its own is equivalent to `str: -1`.)
  - A `str: str` pair, in which case the size of the dimension passed at runtime will be bound to both names, and all dimensions with either name must have the same size. (Some people like to use this as a way to associate multiple names with a dimension, for extra documentation purposes.)
  - A `str: ...` pair, in which case the multiple dimensions corresponding to `...` will be bound to the name specified by `str`, and again checked for consistency between arguments.
  - `None`, which when used in conjunction with `is_named` below, indicates a dimension that must not have a name in the sense of named tensors.
  - A `None: int` pair, combining both `None` and `int` behaviour. (Just a `None` on its own is equivalent to `None: -1`.)
  - A `None: str` pair, combining both `None` and `str` behaviour. (That is, it must not have a named dimension, but must be of a size consistent with other uses of the string.)
  - A `typing.Any`: Any size is allowed for this dimension (equivalent to `-1`).
  - Any tuple of the above. For example `TensorType["batch": ..., "length": 10, "channels", -1]`. If you just want to specify the number of dimensions then use for example `TensorType[-1, -1, -1]` for a three-dimensional tensor.
- The `dtype` argument can be any of:
  - `torch.float32`, `torch.float64` etc.
  - `int`, `bool`, `float`, which are converted to their corresponding PyTorch types. `float` is specifically interpreted as `torch.get_default_dtype()`, which is usually `float32`.
- The `layout` argument can be either `torch.strided` or `torch.sparse_coo`, for dense and sparse tensors respectively.
- The `details` argument offers a way to pass an arbitrary number of additional flags that customise and extend `torchtyping`. Two flags are built-in by default. `torchtyping.is_named` causes the names of tensor dimensions to be checked, and `torchtyping.is_float` can be used to check that arbitrary floating point types are passed in. (Rather than just a specific one as with e.g. `TensorType[torch.float32]`.) For discussion on how to customise `torchtyping` with your own `details`, see the further documentation.
- Check multiple things at once by just putting them all together inside a single `[]`. For example `TensorType["batch": ..., "length", "channels", float, is_named]`.
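The name-binding behaviour of `str` dimensions can be understood with a small stand-alone sketch. This is plain Python, not `torchtyping`'s actual implementation; the function `check_shapes` and its spec format are invented here purely for illustration. Each named dimension is bound to the first size seen at runtime, and every later use must match.

```python
def check_shapes(shapes_and_specs):
    """Toy illustration of torchtyping-style name binding (NOT the real code).

    Each spec is a tuple of specifiers: an int (exact size, -1 for any),
    a str (bound to a size on first use, must match thereafter), or
    Ellipsis (any number of dimensions). Raises TypeError on mismatch.
    """
    bound = {}  # dimension name -> size bound at runtime
    for shape, spec in shapes_and_specs:
        if Ellipsis in spec:
            # Match the specifiers before/after ... against the ends of shape.
            i = spec.index(Ellipsis)
            head, tail = spec[:i], spec[i + 1:]
            if len(shape) < len(head) + len(tail):
                raise TypeError(f"Too few dimensions: {shape} vs {spec}")
            dims = list(zip(head, shape))
            if tail:
                dims += list(zip(tail, shape[len(shape) - len(tail):]))
        else:
            if len(shape) != len(spec):
                raise TypeError(f"Wrong number of dimensions: {shape} vs {spec}")
            dims = list(zip(spec, shape))
        for d, size in dims:
            if isinstance(d, int) and d != -1 and d != size:
                raise TypeError(f"Expected size {d}, got {size}")
            if isinstance(d, str) and bound.setdefault(d, size) != size:
                raise TypeError(f"Dimension {d!r} of inconsistent size. "
                                f"Got both {bound[d]} and {size}.")

# Consistent: "batch" binds to 3 in both shapes, so this passes silently.
check_shapes([((3, 5), ("batch", "x_channels")),
              ((3, 7), ("batch", "y_channels"))])
```

Passing shapes `(3,)` and `(1,)` against two `("batch",)` specs would raise a `TypeError`, mirroring the error message shown in the example above.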
```python
torchtyping.patch_typeguard()
```

`torchtyping` integrates with `typeguard` to perform runtime type checking. `torchtyping.patch_typeguard()` should be called at the global level, and will patch `typeguard` to check `TensorType`s.

This function is safe to run multiple times. (It does nothing after the first run.)

- If using `@typeguard.typechecked`, then `torchtyping.patch_typeguard()` should be called any time before using `@typeguard.typechecked`. For example you could call it at the start of each file using `torchtyping`.
- If using `typeguard.importhook.install_import_hook`, then `torchtyping.patch_typeguard()` should be called any time before defining the functions you want checked. For example you could call `torchtyping.patch_typeguard()` just once, at the same time as the `typeguard` import hook. (The order of the hook and the patch doesn't matter.)
- If you're not using `typeguard` then `torchtyping.patch_typeguard()` can be omitted altogether, and `torchtyping` just used for documentation purposes.
```bash
pytest --torchtyping-patch-typeguard
```

`torchtyping` offers a pytest plugin to automatically run `torchtyping.patch_typeguard()` before your tests. pytest will automatically discover the plugin; you just need to pass the `--torchtyping-patch-typeguard` flag to enable it. Packages can then be passed to `typeguard` as normal, either by using `@typeguard.typechecked`, `typeguard`'s import hook, or the pytest flag `--typeguard-packages="your_package_here"`.
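Combining both flags in a single run looks like the following (with `your_package_here` standing in for your own package name, as above):

```shell
pytest --torchtyping-patch-typeguard --typeguard-packages="your_package_here"
```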
See the further documentation for:
- FAQ, including `flake8` and `mypy` compatibility;
- How to write custom extensions to `torchtyping`;
- Resources and links to other libraries and materials on this topic;
- More examples.