Quaternion Class
- class fastquat.quaternion.Quaternion(w: float | Array, x: float | Array, y: float | Array, z: float | Array)[source]
Bases: object
Class for manipulating quaternion tensors with JAX.
A quaternion is represented by [w, x, y, z] where w is the scalar part and (x, y, z) is the vector part.
- __init__(w: float | Array, x: float | Array, y: float | Array, z: float | Array) None[source]
Initialize a tensor of quaternions.
- Parameters:
w – scalar (real) component of the quaternions.
x – first component of the vector part of the quaternions.
y – second component of the vector part of the quaternions.
z – third component of the vector part of the quaternions.
- classmethod tree_unflatten(aux_data, children) Quaternion[source]
Rebuild a Quaternion from its flattened PyTree representation.
- classmethod from_array(array: Array) Quaternion[source]
Create a Quaternion array from a numeric array of shape (…, 4).
- Parameters:
array – array of shape (…, 4) where the last dimension is [w, x, y, z]
- classmethod from_scalar_vector(scalar: Array, vector: Array) Quaternion[source]
Create a quaternion from scalar and vector parts.
- Parameters:
scalar – Array of shape (…,) for the scalar part.
vector – Array of shape (…, 3) for the vector part.
- Returns:
Quaternion
- classmethod from_rotation_matrix(rot: Array) Quaternion[source]
Create the quaternion corresponding to a rotation matrix.
- Parameters:
rot – Array of shape (…, 3, 3) representing the rotation matrix
- Returns:
The normalized Quaternion tensor representing the rotation matrix.
- classmethod zeros(shape: tuple[int, ...] = (), dtype: ~numpy.dtype = <class 'jax.numpy.float32'>) Quaternion[source]
Create quaternions with all components set to 0.
- Parameters:
shape – Shape of the tensor (without the last dimension).
dtype – Data type of the quaternion components.
- Returns:
Quaternion with all components equal to 0.
- classmethod ones(shape: tuple[int, ...] = (), dtype: ~numpy.dtype = <class 'jax.numpy.float32'>) Quaternion[source]
Create quaternions with scalar component set to 1 and vector components set to 0.
- Parameters:
shape – Shape of the tensor (without the last dimension).
dtype – Data type of the quaternion components.
- Returns:
Quaternions with w=1 and x=y=z=0.
- classmethod full(shape: tuple[int, ...], fill_value: float, dtype: ~numpy.dtype = <class 'jax.numpy.float32'>) Quaternion[source]
Create quaternions with scalar component set to a value and vector components set to 0.
- Parameters:
shape – Shape of the tensor (without the last dimension).
fill_value – Value to fill the scalar component with.
dtype – Data type of the quaternion components.
- Returns:
Quaternions with w=fill_value and x=y=z=0.
- classmethod random(key: PRNGKey, shape: tuple[int, ...] = ()) Quaternion[source]
Generate normalized random quaternions.
- Parameters:
key – JAX PRNG key.
shape – Shape of the tensor (without the last dimension).
- Returns:
Normalized Quaternion.
- normalize() Quaternion[source]
Normalize the quaternion.
Returns the normalized quaternion. If the quaternion has zero norm, returns the zero quaternion [0, 0, 0, 0].
- to_rotation_matrix() Array[source]
Convert quaternion to rotation matrix.
- Returns:
Array of shape (…, 3, 3)
- rotate_vector(v: Array) Array[source]
Apply quaternion rotation to a vector.
- Parameters:
v – Array of shape (…, 3) representing vectors
- Returns:
Array of shape (…, 3) representing rotated vectors
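The relationship between to_rotation_matrix and rotate_vector can be illustrated with a standalone NumPy sketch. It does not call fastquat and is not the library's implementation: the helper names quat_to_rotation_matrix and rotate_vector_sketch are hypothetical, and the input is assumed to be a unit quaternion stored as [w, x, y, z] in the last dimension.

```python
import numpy as np

def quat_to_rotation_matrix(q):
    """Map unit quaternions of shape (..., 4) to matrices of shape (..., 3, 3)."""
    q = np.asarray(q, dtype=float)
    w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    rows = [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]
    # Stack entries into rows, then rows into a matrix.
    return np.stack([np.stack(r, axis=-1) for r in rows], axis=-2)

def rotate_vector_sketch(q, v):
    """Rotate v (..., 3) by unit quaternion q without building a matrix."""
    q = np.asarray(q, dtype=float)
    v = np.asarray(v, dtype=float)
    w, u = q[..., :1], q[..., 1:]
    t = 2.0 * np.cross(u, v)
    return v + w * t + np.cross(u, t)

# Rotation of 90 degrees about the z axis, applied to the x unit vector.
half_angle = np.pi / 4
q = np.array([np.cos(half_angle), 0.0, 0.0, np.sin(half_angle)])
v = np.array([1.0, 0.0, 0.0])
rotated = rotate_vector_sketch(q, v)
via_matrix = quat_to_rotation_matrix(q) @ v
```

Both paths should send the x axis to the y axis, which is a quick sanity check for the [w, x, y, z] ordering.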
- __neg__() Quaternion[source]
Quaternion negation.
- __add__(other: Any) Quaternion[source]
Quaternion addition.
- __radd__(other: Any) Quaternion[source]
Quaternion addition.
- __sub__(other: Any) Quaternion[source]
Quaternion subtraction.
- __rsub__(other: Any) Quaternion[source]
Quaternion subtraction.
- __mul__(other: Any) Quaternion[source]
Quaternion multiplication.
- __rmul__(other: Any) Quaternion[source]
Quaternion multiplication.
- __truediv__(other: Any) Quaternion[source]
Quaternion division.
- __rtruediv__(other: Any) Quaternion[source]
Quaternion division.
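The entries above say only "quaternion multiplication". Assuming this means the usual Hamilton product (an assumption, not confirmed by this page), the operation can be sketched in plain NumPy. The function name hamilton_product is hypothetical and the code does not use fastquat.

```python
import numpy as np

def hamilton_product(p, q):
    """Hamilton product of quaternions stored as [w, x, y, z] in the last axis."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    pw, px, py, pz = np.moveaxis(p, -1, 0)
    qw, qx, qy, qz = np.moveaxis(q, -1, 0)
    return np.stack(
        [
            pw * qw - px * qx - py * qy - pz * qz,
            pw * qx + px * qw + py * qz - pz * qy,
            pw * qy - px * qz + py * qw + pz * qx,
            pw * qz + px * qy - py * qx + pz * qw,
        ],
        axis=-1,
    )

# The basis elements show that the product is not commutative.
i = np.array([0.0, 1.0, 0.0, 0.0])
j = np.array([0.0, 0.0, 1.0, 0.0])
k = np.array([0.0, 0.0, 0.0, 1.0])
ij = hamilton_product(i, j)  # equals k
ji = hamilton_product(j, i)  # equals -k
```

Because the product is not commutative, the order of operands matters when composing rotations.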
- __pow__(exponent: float | int | Array) Quaternion[source]
Quaternion exponentiation q^n.
For integer exponents, uses optimized special cases. For non-integer exponents, uses the general formula: q^n = exp(n * log(q))
- Parameters:
exponent – The exponent (scalar or array)
- Returns:
The quaternion raised to the given power
- log() Quaternion[source]
Compute quaternion logarithm.
For a quaternion q = ‖q‖ * (cos(θ) + sin(θ)v), the logarithm is: log(q) = log(‖q‖) + θ * v
- Returns:
The logarithm of the quaternion
- exp() Quaternion[source]
Compute quaternion exponential.
For a quaternion q = s + v, the exponential is: exp(q) = exp(s) * (cos(‖v‖) + sin(‖v‖) * v/‖v‖)
- Returns:
The exponential of the quaternion
- property ndim
Number of dimensions of the quaternion tensor (without the quaternion dimension).
- property size
Total number of quaternions.
- reshape(*shape) Quaternion[source]
Reshape the quaternion array.
- flatten() Quaternion[source]
Flatten the quaternion array.
- ravel() Quaternion[source]
Flatten the quaternion array.
- squeeze(axis=None) Quaternion[source]
Remove dimensions of size 1.
- conjugate() Quaternion[source]
Quaternion conjugate.
- conj() Quaternion[source]
Quaternion conjugate.
- property device: jax.Device[Any]
- devices() set[jax.Device[Any]][source]
- slerp(other: Quaternion, t: float | Array) Quaternion[source]
Spherical linear interpolation between two quaternions.
- Parameters:
other – Target quaternion to interpolate towards
t – Interpolation parameter in [0, 1]. t=0 returns self, t=1 returns other
- Returns:
Interpolated quaternion
The Quaternion class provides a comprehensive interface for quaternion operations
optimized for JAX. All methods are compatible with JAX transformations including JIT
compilation, automatic differentiation, and vectorization.
Constructor Methods
- Quaternion.__init__(w: float | Array, x: float | Array, y: float | Array, z: float | Array) None[source]
Initialize a tensor of quaternions.
- Parameters:
w – scalar (real) component of the quaternions.
x – first component of the vector part of the quaternions.
y – second component of the vector part of the quaternions.
z – third component of the vector part of the quaternions.
- classmethod Quaternion.from_array(array: Array) Quaternion[source]
Create a Quaternion array from a numeric array of shape (…, 4).
- Parameters:
array – array of shape (…, 4) where the last dimension is [w, x, y, z]
- classmethod Quaternion.from_scalar_vector(scalar: Array, vector: Array) Quaternion[source]
Create a quaternion from scalar and vector parts.
- Parameters:
scalar – Array of shape (…,) for the scalar part.
vector – Array of shape (…, 3) for the vector part.
- Returns:
Quaternion
- classmethod Quaternion.from_rotation_matrix(rot: Array) Quaternion[source]
Create the quaternion corresponding to a rotation matrix.
- Parameters:
rot – Array of shape (…, 3, 3) representing the rotation matrix
- Returns:
The normalized Quaternion tensor representing the rotation matrix.
- classmethod Quaternion.zeros(shape: tuple[int, ...] = (), dtype: ~numpy.dtype = <class 'jax.numpy.float32'>) Quaternion[source]
Create quaternions with all components set to 0.
- Parameters:
shape – Shape of the tensor (without the last dimension).
dtype – Data type of the quaternion components.
- Returns:
Quaternion with all components equal to 0.
- classmethod Quaternion.ones(shape: tuple[int, ...] = (), dtype: ~numpy.dtype = <class 'jax.numpy.float32'>) Quaternion[source]
Create quaternions with scalar component set to 1 and vector components set to 0.
- Parameters:
shape – Shape of the tensor (without the last dimension).
dtype – Data type of the quaternion components.
- Returns:
Quaternions with w=1 and x=y=z=0.
- classmethod Quaternion.full(shape: tuple[int, ...], fill_value: float, dtype: ~numpy.dtype = <class 'jax.numpy.float32'>) Quaternion[source]
Create quaternions with scalar component set to a value and vector components set to 0.
- Parameters:
shape – Shape of the tensor (without the last dimension).
fill_value – Value to fill the scalar component with.
dtype – Data type of the quaternion components.
- Returns:
Quaternions with w=fill_value and x=y=z=0.
- classmethod Quaternion.random(key: PRNGKey, shape: tuple[int, ...] = ()) Quaternion[source]
Generate normalized random quaternions.
- Parameters:
key – JAX PRNG key.
shape – Shape of the tensor (without the last dimension).
- Returns:
Normalized Quaternion.
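The constructors above share one layout convention: the last dimension of a raw array holds [w, x, y, z], and the quaternion tensor's shape excludes that dimension. The NumPy sketch below only illustrates the convention. It does not call fastquat, and all variable names are illustrative.

```python
import numpy as np

# A raw array of shape (2, 3, 4): a 2 x 3 grid of quaternions.
arr = np.arange(24, dtype=np.float32).reshape(2, 3, 4)

# Split into the parts that from_scalar_vector takes as arguments.
scalar = arr[..., 0]    # shape (2, 3)
vector = arr[..., 1:]   # shape (2, 3, 3)

# The quaternion tensor shape is the array shape without the last dimension.
batch_shape = arr.shape[:-1]

# Recombining scalar and vector parts recovers the (..., 4) layout.
rebuilt = np.concatenate([scalar[..., None], vector], axis=-1)

# Layout of what ones(shape) is documented to return: w = 1, x = y = z = 0.
identity = np.zeros(batch_shape + (4,), dtype=np.float32)
identity[..., 0] = 1.0
```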
Properties
- Quaternion.w
- Quaternion.x
- Quaternion.y
- Quaternion.z
- Quaternion.vector
Vector part, of shape (…, 3).
- Quaternion.shape
Shape of the tensor.
- Quaternion.dtype
Data type.
Core Operations
- Quaternion.normalize() Quaternion[source]
Normalize the quaternion.
Returns the normalized quaternion. If the quaternion has zero norm, returns the zero quaternion [0, 0, 0, 0].
- Quaternion.conjugate() Quaternion[source]
Quaternion conjugate.
- Quaternion.conj() Quaternion[source]
Quaternion conjugate.
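The zero-norm behaviour documented for normalize can be sketched in plain NumPy as follows. This is a minimal illustration of one way to get that behaviour, not the library's code. The function name normalize_sketch is hypothetical.

```python
import numpy as np

def normalize_sketch(q):
    """Scale quaternions (..., 4) to unit norm; zero quaternions stay zero."""
    q = np.asarray(q, dtype=float)
    norm = np.linalg.norm(q, axis=-1, keepdims=True)
    # Replace zero norms by 1 so the division is well defined.
    safe_norm = np.where(norm > 0, norm, 1.0)
    return q / safe_norm

batch = np.array([[2.0, 0.0, 0.0, 0.0],
                  [1.0, 1.0, 1.0, 1.0],
                  [0.0, 0.0, 0.0, 0.0]])
unit = normalize_sketch(batch)
norms = np.linalg.norm(unit, axis=-1)
```

Guarding the denominator before dividing, rather than patching the result afterwards, is also a common pattern in JAX code because it keeps NaNs out of gradients.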
Rotation Operations
- Quaternion.to_rotation_matrix() Array[source]
Convert quaternion to rotation matrix.
- Returns:
Array of shape (…, 3, 3)
- Quaternion.rotate_vector(v: Array) Array[source]
Apply quaternion rotation to a vector.
- Parameters:
v – Array of shape (…, 3) representing vectors
- Returns:
Array of shape (…, 3) representing rotated vectors
Interpolation
- Quaternion.slerp(other: Quaternion, t: float | Array) Quaternion[source]
Spherical linear interpolation between two quaternions.
- Parameters:
other – Target quaternion to interpolate towards
t – Interpolation parameter in [0, 1]. t=0 returns self, t=1 returns other
- Returns:
Interpolated quaternion
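Spherical linear interpolation can be sketched in NumPy as below. This is a textbook formulation, not necessarily what fastquat does: the function name slerp_sketch is hypothetical, the inputs are assumed to be unit quaternions in [w, x, y, z] order, and taking the shorter arc by flipping the sign of the target is an assumption that this page does not confirm.

```python
import numpy as np

def slerp_sketch(q0, q1, t):
    """Interpolate between unit quaternions q0 and q1 (..., 4) at parameter t."""
    q0 = np.asarray(q0, dtype=float)
    q1 = np.asarray(q1, dtype=float)
    t = np.asarray(t, dtype=float)[..., None]
    dot = np.sum(q0 * q1, axis=-1, keepdims=True)
    # q and -q encode the same rotation; flip q1 to take the shorter arc.
    q1 = np.where(dot < 0, -q1, q1)
    theta = np.arccos(np.clip(np.abs(dot), 0.0, 1.0))
    sin_theta = np.sin(theta)
    # Fall back to linear weights when the two quaternions nearly coincide.
    small = sin_theta < 1e-6
    safe_sin = np.where(small, 1.0, sin_theta)
    a = np.where(small, 1.0 - t, np.sin((1.0 - t) * theta) / safe_sin)
    b = np.where(small, t, np.sin(t * theta) / safe_sin)
    return a * q0 + b * q1

# Halfway between the identity and a 90 degree rotation about z.
q0 = np.array([1.0, 0.0, 0.0, 0.0])
q1 = np.array([np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)])
midpoint = slerp_sketch(q0, q1, 0.5)
```

The midpoint should be the 45 degree rotation about z, and the endpoints t=0 and t=1 should reproduce the two inputs.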
Advanced Operations
- Quaternion.log() Quaternion[source]
Compute quaternion logarithm.
For a quaternion q = ‖q‖ * (cos(θ) + sin(θ)v), the logarithm is: log(q) = log(‖q‖) + θ * v
- Returns:
The logarithm of the quaternion
- Quaternion.exp() Quaternion[source]
Compute quaternion exponential.
For a quaternion q = s + v, the exponential is: exp(q) = exp(s) * (cos(‖v‖) + sin(‖v‖) * v/‖v‖)
- Returns:
The exponential of the quaternion
- Quaternion.__pow__(exponent: float | int | Array) Quaternion[source]
Quaternion exponentiation q^n.
For integer exponents, uses optimized special cases. For non-integer exponents, uses the general formula: q^n = exp(n * log(q))
- Parameters:
exponent – The exponent (scalar or array)
- Returns:
The quaternion raised to the given power
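The three formulas in this section fit together as shown in the NumPy sketch below. It is a minimal illustration under stated assumptions, not fastquat's implementation: the function names are hypothetical, quaternions are stored as [w, x, y, z], and only the general path q^n = exp(n * log(q)) is shown (none of the integer special cases). Zero and negative real quaternions are not handled specially.

```python
import numpy as np

def quat_exp(q):
    """exp(s + v) = exp(s) * (cos|v| + sin|v| * v/|v|)."""
    q = np.asarray(q, dtype=float)
    s, v = q[..., :1], q[..., 1:]
    v_norm = np.linalg.norm(v, axis=-1, keepdims=True)
    safe = np.where(v_norm > 0, v_norm, 1.0)
    # sin|v| / |v| tends to 1 as |v| tends to 0.
    sinc = np.where(v_norm > 0, np.sin(v_norm) / safe, 1.0)
    return np.exp(s) * np.concatenate([np.cos(v_norm), sinc * v], axis=-1)

def quat_log(q):
    """log(q) = log|q| + theta * v/|v|, with theta the angle between q and the real axis."""
    q = np.asarray(q, dtype=float)
    s, v = q[..., :1], q[..., 1:]
    q_norm = np.linalg.norm(q, axis=-1, keepdims=True)
    v_norm = np.linalg.norm(v, axis=-1, keepdims=True)
    theta = np.arctan2(v_norm, s)
    safe = np.where(v_norm > 0, v_norm, 1.0)
    return np.concatenate([np.log(q_norm), theta * v / safe], axis=-1)

def quat_pow(q, n):
    """General power via q^n = exp(n * log(q))."""
    return quat_exp(n * quat_log(q))

# A 90 degree rotation about z: its square root is the 45 degree rotation.
q = np.array([np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)])
root = quat_pow(q, 0.5)
square = quat_pow(q, 2)
roundtrip = quat_exp(quat_log(q))
```

For unit quaternions, raising to a fractional power scales the rotation angle, which is why slerp is sometimes written in terms of powers.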
Array Operations
- Quaternion.reshape(*shape) Quaternion[source]
Reshape the quaternion array.
- Quaternion.flatten() Quaternion[source]
Flatten the quaternion array.
- Quaternion.ravel() Quaternion[source]
Flatten the quaternion array.
- Quaternion.squeeze(axis=None) Quaternion[source]
Remove dimensions of size 1.
Device and Memory
- Quaternion.device
- Quaternion.devices() set[jax.Device[Any]][source]
- Quaternion.nbytes
Number of bytes in the tensor.
- Quaternion.itemsize
Size of one quaternion element in bytes.
- Quaternion.size
Total number of quaternions.
- Quaternion.ndim
Number of dimensions of the quaternion tensor (without the quaternion dimension).