SLERP Tutorial
This tutorial shows how to perform Spherical Linear Interpolation (SLERP) between two quaternions, with an animated visualization of the path that a vector traces on the unit sphere as it is rotated by the interpolated quaternions.
[1]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML
from matplotlib.animation import FuncAnimation
from fastquat import Quaternion
# Configure matplotlib for better rendering
plt.rcParams['figure.dpi'] = 100
# Render FuncAnimation objects as an interactive JavaScript widget when they are
# displayed directly. This matches how the animation in this notebook is rendered
# (`anim.to_jshtml()`) and, unlike 'html5', does not require ffmpeg to be installed.
plt.rcParams['animation.html'] = 'jshtml'
Setup: Define Quaternions and Test Vector
We’ll interpolate between two quaternions that represent significantly different rotations to make the arc clearly visible.
[2]:
# Define two quaternions with a large angular separation for a visible arc
# q1: Identity (no rotation)
q1 = Quaternion(1.0)
# q2: 120 degree rotation around axis (1, 1, 1) for a large, visible arc
axis = jnp.array([1.0, 1.0, 1.0])
axis = axis / jnp.linalg.norm(axis)  # Normalize
angle = 2 * jnp.pi / 3  # 120 degrees
# Axis-angle -> quaternion: scalar part cos(angle/2), vector part sin(angle/2) * axis
q2 = Quaternion.from_scalar_vector(jnp.cos(angle / 2), jnp.sin(angle / 2) * axis)

# 4D dot product of the quaternion components, computed once and reused below
dot = jnp.sum(q1.wxyz * q2.wxyz)
# Rotation angle between the two orientations is 2 * arccos(|dot|).
# abs() accounts for q and -q describing the same rotation; the clip keeps the
# arccos argument in [0, 1] so float rounding cannot turn the result into NaN.
separation_deg = jnp.degrees(2 * jnp.arccos(jnp.clip(jnp.abs(dot), 0.0, 1.0)))

print(f'Start quaternion q1: {q1}')
print(f'End quaternion q2: {q2}')
print(f'Dot product: {dot:.4f}')
print(f'Angular separation: {separation_deg:.2f} degrees')
Start quaternion q1: 1.0 + 0.0i + 0.0j + 0.0k
End quaternion q2: 0.4999999701976776 + 0.5i + 0.5j + 0.5k
Dot product: 0.5000
Angular separation: 120.00 radians
Generate SLERP Interpolation
Sample the SLERP path at evenly spaced values of the interpolation parameter t ∈ [0, 1], then apply each interpolated quaternion to a test vector.
[3]:
# One animation frame per interpolation step; t runs evenly from 0 (q1) to 1 (q2)
n_frames = 50
t_values = jnp.linspace(0, 1, n_frames)

# The vector each interpolated rotation acts on: the unit x-axis
test_vector = jnp.array([1.0, 0.0, 0.0])

# A single slerp call evaluates the interpolation at every t at once;
# rotating the test vector by each result gives the points of the arc
slerp_quaternions = q1.slerp(q2, t_values)
slerp_rotated = slerp_quaternions.rotate_vector(test_vector)

print(f'Generated {slerp_rotated.shape[0]} frames for animation')
print(f'All quaternions normalized: {jnp.allclose(abs(slerp_quaternions), 1.0)}')
Generated 50 frames for animation
All quaternions normalized: True
Animated Display
Now create an animation showing the rotated test vector progressing along the arc, from the start point (t = 0) to the end point (t = 1).
[4]:
# Create animated visualization
# Matplotlib works natively with NumPy arrays: convert the JAX path once here
# instead of handing JAX arrays to matplotlib on every frame.
path_xyz = np.asarray(slerp_rotated)

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Draw unit sphere (static)
u = np.linspace(0, 2 * np.pi, 50)
v = np.linspace(0, np.pi, 50)
sphere_x = np.outer(np.cos(u), np.sin(v))
sphere_y = np.outer(np.sin(u), np.sin(v))
sphere_z = np.outer(np.ones(np.size(u)), np.cos(v))
ax.plot_surface(sphere_x, sphere_y, sphere_z, alpha=0.1, color='lightgray')

# Plot the complete path in light color
ax.plot(
    path_xyz[:, 0],
    path_xyz[:, 1],
    path_xyz[:, 2],
    '--',
    color='lightblue',
    linewidth=1,
    alpha=0.5,
    label='Full Path',
)

# Start and end points
ax.scatter(*path_xyz[0], color='green', s=100, label='Start', zorder=5)
ax.scatter(*path_xyz[-1], color='red', s=100, label='End', zorder=5)

# Animated elements: created empty here and filled in frame by frame by `animate`
current_point = ax.scatter([], [], [], color='blue', s=150, zorder=6)
(progress_line,) = ax.plot([], [], [], 'b-', linewidth=3, label='Progress')

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('SLERP Animation on Unit Sphere')
ax.legend()
ax.set_box_aspect([1, 1, 1])
# Set fixed view limits
ax.set_xlim([-1.2, 1.2])
ax.set_ylim([-1.2, 1.2])
ax.set_zlim([-1.2, 1.2])
# Fixed camera angle: 20° elevation, 45° azimuth
ax.view_init(elev=20, azim=45)

# The origin-to-point arrow is recreated on every frame, so keep a handle to
# the previous one in order to remove it before drawing the next.
origin = [0, 0, 0]
vector_arrow = None


def animate(frame):
    """Draw frame `frame` of the SLERP animation.

    Moves the current-point marker to the frame's position, extends the
    progress line over the path up to and including `frame`, and replaces the
    arrow from the origin to the current point.

    Returns the updated marker and line artists (FuncAnimation only uses the
    returned artists when blitting, which is disabled below).
    """
    global vector_arrow
    current_pos = path_xyz[frame]

    # Move the marker; `_offsets3d` is a private attribute of matplotlib's
    # 3D scatter collection.
    current_point._offsets3d = ([current_pos[0]], [current_pos[1]], [current_pos[2]])

    # Path traversed so far
    progress_line.set_data_3d(
        path_xyz[: frame + 1, 0],
        path_xyz[: frame + 1, 1],
        path_xyz[: frame + 1, 2],
    )

    # Replace the previous arrow with one pointing at the current position
    if vector_arrow is not None:
        vector_arrow.remove()
    vector_arrow = ax.quiver(
        *origin,
        *current_pos,
        color='red',
        arrow_length_ratio=0.1,
        linewidth=3,
        alpha=0.7,
    )
    return current_point, progress_line


# One frame per interpolated point, 100 ms apart, looping
anim = FuncAnimation(fig, animate, frames=len(path_xyz), interval=100, blit=False, repeat=True)
# Close the figure so the notebook shows only the animation widget, not an extra static image
plt.close(fig)
HTML(anim.to_jshtml())
[4]: