Getting Started
This guide will get you started with FastQuat’s core functionality.
Creating Quaternions
There are several ways to create quaternions:
[1]:
import jax
import jax.numpy as jnp
import jax.random as jr
from fastquat import Quaternion
# From components (w, x, y, z)
q1 = Quaternion(0.0, -1.0, -1.0, -1.0)
identity = Quaternion(1.0)
i = Quaternion(x=1)
j = Quaternion(y=1)
k = Quaternion(z=1)
# Convenience constructors
q_zeros = Quaternion.zeros((2, 3))
q_ones = Quaternion.ones((2, 3))
q_twos = Quaternion.full((2, 3), 2)
# From arrays
array = jnp.array([1.0, 0.0, 0.0, 0.0])
q2 = Quaternion.from_array(array)
# Random normalized quaternions
key = jr.key(42)
q_random = Quaternion.random(key)
Basic Operations
Quaternions support standard mathematical operations:
[2]:
q = Quaternion(1.0, 0.1, 0.2, 0.3)
p = 2
# Arithmetic
q_sum = q1 + q2
q_diff = q1 - q2
q_product = q1 * q2
q_power = q**p
# Normalization
norm = abs(q) # Quaternion norm
q_unit = q.normalize() # Unit quaternion
# Conjugation and inverse
q_conj = q.conj() # Conjugate
q_inv = 1 / q # Inverse, or q ** -1
# Other operations
q_log = q.log()
q_exp = q.exp()
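For readers who want to see what these operators compute, here is a minimal sketch of the underlying quaternion algebra in plain Python, using `(w, x, y, z)` tuples. The helper names `qmul`, `qconj`, `qnorm`, and `qinv` are hypothetical and are not part of FastQuat. The sketch illustrates the standard definitions only and makes no claim about how FastQuat implements them.

```python
import math

def qmul(a, b):
    """Hamilton product of two quaternions given as (w, x, y, z) tuples."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return (
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    )

def qconj(q):
    """Conjugate: negate the vector part."""
    w, x, y, z = q
    return (w, -x, -y, -z)

def qnorm(q):
    """Euclidean norm of the four components."""
    return math.sqrt(sum(c * c for c in q))

def qinv(q):
    """Inverse: conjugate divided by the squared norm."""
    n2 = sum(c * c for c in q)
    return tuple(c / n2 for c in qconj(q))

i = (0.0, 1.0, 0.0, 0.0)
j = (0.0, 0.0, 1.0, 0.0)
k = (0.0, 0.0, 0.0, 1.0)
q = (1.0, 0.1, 0.2, 0.3)

ij = qmul(i, j)                 # equals k
ji = qmul(j, i)                 # equals -k: the product is not commutative
q_times_inv = qmul(q, qinv(q))  # approximately the identity (1, 0, 0, 0)
```

Because the product is not commutative, the order of the operands matters whenever rotations are composed by multiplication.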
Vector Rotation
One of the most common uses of quaternions is rotating 3D vectors:
[3]:
# Create a 90° rotation around the z-axis
angle = jnp.pi / 2
q_rot = Quaternion(jnp.cos(angle / 2), 0.0, 0.0, jnp.sin(angle / 2))
# Rotate a vector
vector = jnp.array([1.0, 0.0, 0.0]) # Unit vector along x
rotated = q_rot.rotate_vector(vector)
print(rotated) # Should be approximately [0, 1, 0]
[0. 1. 0.]
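The result above can be checked by hand. For a unit quaternion with scalar part w and vector part u, the rotation q v q* expands to v + 2w (u × v) + 2 u × (u × v). The plain-Python sketch below applies that identity; the helpers `cross` and `rotate` are hypothetical names used for illustration, and FastQuat's `rotate_vector` may be implemented differently.

```python
import math

def cross(a, b):
    """Cross product of two 3-vectors given as tuples."""
    return (
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    )

def rotate(q, v):
    """Rotate v by the unit quaternion q = (w, x, y, z)."""
    w, u = q[0], q[1:]
    uv = cross(u, v)       # u x v
    uuv = cross(u, uv)     # u x (u x v)
    return tuple(v[n] + 2.0 * w * uv[n] + 2.0 * uuv[n] for n in range(3))

angle = math.pi / 2
q_rot = (math.cos(angle / 2), 0.0, 0.0, math.sin(angle / 2))
rotated = rotate(q_rot, (1.0, 0.0, 0.0))  # approximately (0, 1, 0)
```

The identity holds only for unit quaternions, which is one reason to normalize before rotating.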
Conversion to/from Rotation Matrices
FastQuat can convert between quaternions and rotation matrices:
[4]:
# Quaternion to rotation matrix
R = q_rot.to_rotation_matrix()
print(R.shape) # (3, 3)
# Rotation matrix to quaternion
q_from_matrix = Quaternion.from_rotation_matrix(R)
(3, 3)
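As an illustration of what the conversion produces, the sketch below builds the rotation matrix of a unit quaternion from the standard closed-form expression, in plain Python. The function name `quat_to_matrix` is hypothetical. The sketch assumes the active-rotation, `(w, x, y, z)` convention that matches the rotation example above; the convention used by FastQuat's `to_rotation_matrix` should be confirmed in the API reference.

```python
import math

def quat_to_matrix(q):
    """3x3 rotation matrix (nested lists) for a unit quaternion (w, x, y, z)."""
    w, x, y, z = q
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]

angle = math.pi / 2
q_z90 = (math.cos(angle / 2), 0.0, 0.0, math.sin(angle / 2))
R_z90 = quat_to_matrix(q_z90)
# R_z90 is approximately [[0, -1, 0], [1, 0, 0], [0, 0, 1]]
```

The first column of the matrix is the image of the x unit vector, so it agrees with the rotated vector printed in the previous section.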
Spherical Linear Interpolation (SLERP)
SLERP provides smooth interpolation between quaternions:
[5]:
# Two different orientations
q_start = Quaternion(1.0) # Identity
q_end = Quaternion(0.7071, 0.7071, 0.0, 0.0) # 90° around x
# Interpolate between them
t = 0.5 # Halfway point
q_mid = q_start.slerp(q_end, t)
# Batch interpolation
t_values = jnp.linspace(0, 1, 10)
interpolated = q_start.slerp(q_end, t_values)
print(interpolated.shape) # (10,) - 10 quaternions
(10,)
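The interpolation itself follows a short formula: with θ the angle between the two unit quaternions, the result is sin((1 − t)θ)/sin θ · q₀ + sin(tθ)/sin θ · q₁. The plain-Python sketch below implements it under two assumptions that are not confirmed for FastQuat's `slerp`: it flips the sign of the second quaternion to follow the shorter arc, and it falls back to linear interpolation when the inputs nearly coincide. The function name `slerp_sketch` is hypothetical.

```python
import math

def slerp_sketch(q0, q1, t):
    """Spherical linear interpolation between unit quaternions (w, x, y, z)."""
    dot = sum(a * b for a, b in zip(q0, q1))
    if dot < 0.0:
        # q and -q encode the same rotation; flip to take the shorter arc.
        q1 = tuple(-c for c in q1)
        dot = -dot
    theta = math.acos(min(dot, 1.0))
    if theta < 1e-8:
        # Nearly identical inputs: sin(theta) is close to 0, so interpolate linearly.
        return tuple((1.0 - t) * a + t * b for a, b in zip(q0, q1))
    s = math.sin(theta)
    w0 = math.sin((1.0 - t) * theta) / s
    w1 = math.sin(t * theta) / s
    return tuple(w0 * a + w1 * b for a, b in zip(q0, q1))

h = math.sqrt(0.5)
start = (1.0, 0.0, 0.0, 0.0)  # identity
end = (h, h, 0.0, 0.0)        # 90° around x
mid = slerp_sketch(start, end, 0.5)  # 45° around x: (cos(pi/8), sin(pi/8), 0, 0)
```

Unlike component-wise linear interpolation, the result stays on the unit sphere and the rotation angle grows linearly with `t`.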
JAX Integration
FastQuat quaternions can be passed through JAX transformations such as jit, vmap, and grad:
[6]:
# JIT compilation
@jax.jit
def rotate_and_normalize(q, v):
    rotated = q.rotate_vector(v)
    return rotated / jnp.linalg.norm(rotated)
# Vectorization
batch_rotate = jax.vmap(lambda q, v: q.rotate_vector(v))
# Create batches
key_q, key_v = jr.split(key)  # independent keys for the two random draws
q_batch = Quaternion.random(key_q, shape=(100,))
v_batch = jr.normal(key_v, (100, 3))
# Process entire batch at once
rotated_batch = batch_rotate(q_batch, v_batch)
# Automatic differentiation
def loss_function(q_params):
    q = Quaternion.from_array(q_params)
    rotated = q.rotate_vector(vector)
    return jnp.sum(rotated**2)
grad_fn = jax.grad(loss_function)
gradients = grad_fn(jnp.array([1.0, 0.1, 0.1, 0.1]))
Performance Tips
Use JIT compilation for repeated operations:
[7]:
@jax.jit
def batch_operation(quaternions):
    return quaternions.normalize()
Prefer batch operations over loops:
[8]:
# Good: vectorized operation
results = q_batch.rotate_vector(v_batch)
# Avoid: Python loops
# results = [q.rotate_vector(v) for q, v in zip(q_batch, v_batch)]
Normalize quaternions when needed for rotations:
[9]:
q_unit = q.normalize() # Ensure unit quaternion for rotations
Next Steps
Explore the tutorials for detailed use cases
Check the API reference for the complete documentation
See advanced interpolation techniques with SLERP