Advanced Optimization Techniques
This tutorial covers advanced optimization techniques for writing high-performance brain dynamics operators with braintaichi.
import numpy as np
import jax
import jax.numpy as jnp
import taichi as ti
import time
from scipy.sparse import csr_matrix
import braintaichi as bti
1. Loop Configuration and Parallelization
Taichi automatically parallelizes the outermost for loops of a kernel (loops nested inside them run serially within each parallel iteration), but you can fine-tune the parallelization behavior for optimal performance.
1.1 Serial vs Parallel Execution
# Serial execution (useful when order matters or when the loop needs `break`).
# Sums `arr` in index order and stops as soon as the running sum exceeds 100.0;
# the element that pushes the sum past 100.0 is included. The result is stored
# in out[0].
# ti.loop_config(serialize=True) applies to the for loop that immediately
# follows it and makes that loop run sequentially; Taichi only permits `break`
# in a loop that is not parallelized.
@ti.kernel
def serial_sum(
    arr: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    total = 0.0
    ti.loop_config(serialize=True)  # must directly precede the loop it configures
    for i in range(arr.shape[0]):
        total += arr[i]
        if total > 100.0:  # `break` is legal here only because the loop is serialized
            break
    out[0] = total
# Parallel execution (the default for a kernel's top-level loops): every output
# element depends only on its own input element, so iterations are independent
# and Taichi can spread them across threads.
@ti.kernel
def parallel_multiply(
    arr: ti.types.ndarray(ndim=1),
    scalar: ti.f32,
    out: ti.types.ndarray(ndim=1)
):
    n = arr.shape[0]
    # Top-level loop: parallelized automatically, no configuration needed.
    for k in range(n):
        out[k] = scalar * arr[k]
1.2 Block Dimension Tuning for GPU
@ti.kernel
def optimized_gpu_kernel(
    data: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    # block_dim sets the number of GPU threads per block for the loop that
    # immediately follows; common choices are 64, 128, 256 and 512.
    ti.loop_config(block_dim=256)
    for idx in range(data.shape[0]):
        x = data[idx]
        out[idx] = ti.sqrt(x) + ti.sin(x)
# For CPU, configure the number of parallel threads: `parallelize=8` asks Taichi
# to run the loop that follows with 8 CPU threads.
@ti.kernel
def optimized_cpu_kernel(
    data: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    ti.loop_config(parallelize=8)  # applies to the next loop only
    for idx in range(data.shape[0]):
        x = data[idx]
        out[idx] = ti.sqrt(x) + ti.sin(x)
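2. Memory Access Optimization
Efficient memory access patterns are crucial for performance. On CPU, each thread should walk memory in storage order (row by row for a row-major array) so that consecutive reads hit the same cache lines; on GPU, what matters most is that neighboring threads read neighboring addresses at the same time (coalescing).
2.1 Contiguous Memory Access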
# BAD (on CPU): strided access within each iteration.
# Column sums: the parallel top-level loop runs over columns, and each
# iteration walks down one column. For a row-major (C-order) matrix, which is
# assumed here, consecutive reads inside one iteration are a full row apart,
# which is cache-unfriendly on CPU. (On GPU, neighboring iterations read
# neighboring addresses of the same row, so this pattern is not the bad case.)
@ti.kernel
def bad_memory_access(
    matrix: ti.types.ndarray(ndim=2),
    out: ti.types.ndarray(ndim=1)
):
    n_rows, n_cols = matrix.shape
    for c in range(n_cols):
        acc = 0.0
        for r in range(n_rows):
            acc += matrix[r, c]  # stride of one row between successive reads
        out[c] = acc
# GOOD (on CPU): contiguous access within each iteration.
# Column sums with rows traversed in storage order: each parallel iteration
# reads one row contiguously, which is cache-friendly on CPU. Different rows
# add into the same out[col]; this is safe because Taichi compiles `+=` on
# ndarray elements to an atomic add (on GPU those atomics contend on the same
# m addresses, so this variant is mainly a CPU-side improvement).
@ti.kernel
def good_memory_access(
    matrix: ti.types.ndarray(ndim=2),
    out: ti.types.ndarray(ndim=1)
):
    n, m = matrix.shape
    # Zero the accumulator first: the kernel sums into `out`, so without this
    # the result would depend on whatever the output buffer held before.
    for col in range(m):
        out[col] = 0.0
    # Top-level loops run one after another, so zeroing finishes before
    # accumulation starts.
    for row in range(n):
        for col in range(m):
            out[col] += matrix[row, col]  # contiguous read of row `row`
2.2 Using Local Variables
# BAD: the same global element is read several times in one expression
# (the compiler may or may not merge these loads).
@ti.kernel
def bad_global_access(
    arr: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for k in range(arr.shape[0]):
        out[k] = arr[k] * arr[k] + arr[k] * 2.0  # arr[k] read three times
# GOOD: read the element once into a local and reuse it.
@ti.kernel
def good_local_cache(
    arr: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for k in range(arr.shape[0]):
        x = arr[k]  # single load from the input array
        out[k] = x * x + x * 2.0
3. Optimizing Sparse Operations
Sparse operations are common in brain dynamics. Here are optimization strategies for sparse matrices.
3.1 Optimized CSR Matrix-Vector Multiplication
# Standard implementation of a CSR matrix-vector product:
# out[r] = sum of values[k] * vector[indices[k]] over k in [indptr[r], indptr[r + 1]).
# The top-level loop over rows is parallelized; each row is summed serially.
@ti.kernel
def csr_matvec_standard(
    values: ti.types.ndarray(ndim=1),
    indices: ti.types.ndarray(ndim=1),
    indptr: ti.types.ndarray(ndim=1),
    vector: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    n_rows = indptr.shape[0] - 1
    for r in range(n_rows):
        acc = 0.0
        for k in range(indptr[r], indptr[r + 1]):
            acc += values[k] * vector[indices[k]]
        out[r] = acc
# CSR matrix-vector product with every load bound to a named local: the row
# bounds, the column index, the stored value and the gathered vector entry.
# It computes the same result as csr_matvec_standard; each of those values is
# read exactly once in both versions.
@ti.kernel
def csr_matvec_optimized(
    values: ti.types.ndarray(ndim=1),
    indices: ti.types.ndarray(ndim=1),
    indptr: ti.types.ndarray(ndim=1),
    vector: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for r in range(indptr.shape[0] - 1):
        lo = indptr[r]
        hi = indptr[r + 1]
        acc = 0.0
        for k in range(lo, hi):
            c = indices[k]
            w = values[k]
            x = vector[c]
            acc += w * x
        out[r] = acc
# Benchmark the two implementations
def benchmark_csr_matvec(n=10000, density=0.01, n_runs=10):
    """Time ``csr_matvec_standard`` against ``csr_matvec_optimized`` and print the results.

    A random ``n x n`` CSR matrix with the requested ``density`` and a random
    dense vector are generated. Both kernels are registered as
    ``bti.XLACustomOp`` operators, warmed up, and then timed over ``n_runs``
    calls each. The mean wall-clock time per call is printed.

    Both kernels load each nonzero's value, column index and vector entry
    exactly once, so they may well compile to equivalent code; treat small
    differences in the printed timings as measurement noise.

    Args:
        n: Number of rows (and columns) of the square sparse matrix.
        density: Fraction of matrix entries that are nonzero.
        n_runs: Number of timed calls per implementation.
    """
    from scipy.sparse import random as sparse_random

    # Create the sparse matrix directly in CSR form. Building dense n x n random
    # matrices instead (a mask times a weight matrix) costs ~800 MB per float64
    # array at n = 10000.
    sparse_mat = sparse_random(n, n, density=density, format="csr", dtype=np.float32)
    vector = np.random.rand(n).astype(np.float32)

    # Register operators
    op_standard = bti.XLACustomOp(cpu_kernel=csr_matvec_standard, gpu_kernel=csr_matvec_standard)
    op_optimized = bti.XLACustomOp(cpu_kernel=csr_matvec_optimized, gpu_kernel=csr_matvec_optimized)

    # Prepare inputs
    values = jnp.array(sparse_mat.data, dtype=jnp.float32)
    indices = jnp.array(sparse_mat.indices, dtype=jnp.int32)
    indptr = jnp.array(sparse_mat.indptr, dtype=jnp.int32)
    vec = jnp.array(vector, dtype=jnp.float32)

    def run(op):
        # JAX dispatches asynchronously: block until the result is ready so the
        # timer measures the kernel, not just the launch.
        return jax.block_until_ready(
            op(values, indices, indptr, vec,
               outs=[jax.ShapeDtypeStruct((n,), dtype=jnp.float32)])
        )

    def mean_time(op):
        for _ in range(3):  # warm up: the first calls include compilation
            run(op)
        start = time.perf_counter()
        for _ in range(n_runs):
            run(op)
        return (time.perf_counter() - start) / n_runs

    time_standard = mean_time(op_standard)
    time_optimized = mean_time(op_optimized)

    print(f"Standard implementation: {time_standard*1000:.3f} ms")
    print(f"Optimized implementation: {time_optimized*1000:.3f} ms")
    print(f"Speedup: {time_standard/time_optimized:.2f}x")


benchmark_csr_matvec()
Standard implementation: 0.200 ms
Optimized implementation: 0.300 ms
Speedup: 0.67x
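4. Event-Driven Optimization
Event-driven computations are essential for spiking neural networks: at each step only a small fraction of neurons spike, so work can be restricted to the synapses of the neurons that actually fired. Here’s how to optimize them.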
# Method 1: Direct event checking (simple: every row is visited, and rows whose
# neuron did not fire are skipped).
# For each presynaptic row whose entry in `events` is set, `weight` is added to
# out[c] for every column index c stored in that CSR row.
# The loop is serialized with ti.loop_config(serialize=True). In Taichi, `+=`
# on ndarray elements is already atomic, so serialization is not required for
# correctness. Presumably it was chosen for CPU performance (no contention on
# shared targets), so profile to confirm. Branch divergence only becomes a
# concern if this loop is run in parallel on a GPU; as written it runs serially.
# NOTE(review): `out` is accumulated into and never zeroed here. Confirm that
# the output buffer is zero-initialized before the call.
@ti.kernel
def event_driven_v1(
    indices: ti.types.ndarray(ndim=1),
    indptr: ti.types.ndarray(ndim=1),
    events: ti.types.ndarray(ndim=1),  # presumably a bool / 0-1 spike flag per row; confirm
    weight: ti.f32,  # same weight added for every synapse
    out: ti.types.ndarray(ndim=1)
):
    ti.loop_config(serialize=True)  # run the following loop sequentially
    for row in range(indptr.shape[0] - 1):
        if events[row]:  # Check event: skip rows whose neuron did not fire
            for j in range(indptr[row], indptr[row + 1]):
                out[indices[j]] += weight
# Method 2: Event filtering (better for low firing rates).
# Instead of scanning every row, the caller supplies the row indices of the
# neurons that fired, so the loop visits only those rows. For each of them,
# `weight` is added to out[c] for every column index c stored in that CSR row.
# It performs the same updates as Method 1 when `event_indices` lists exactly
# the rows whose event flag is set.
# The loop is serialized with ti.loop_config(serialize=True). Taichi's `+=` on
# ndarray elements is already atomic, so this is presumably a CPU performance
# choice rather than a correctness requirement.
# NOTE(review): `out` is accumulated into and never zeroed here. Confirm that
# the output buffer is zero-initialized before the call.
@ti.kernel
def event_driven_v2(
    indices: ti.types.ndarray(ndim=1),
    indptr: ti.types.ndarray(ndim=1),
    event_indices: ti.types.ndarray(ndim=1), # Indices of neurons that fired
    weight: ti.f32,  # same weight added for every synapse
    out: ti.types.ndarray(ndim=1)
):
    # Only iterate over neurons that actually fired
    ti.loop_config(serialize=True)  # run the following loop sequentially
    for i in range(event_indices.shape[0]):
        row = event_indices[i]
        for j in range(indptr[row], indptr[row + 1]):
            out[indices[j]] += weight
# Benchmark event-driven methods
def benchmark_event_driven(n=10000, density=0.01, firing_rate=0.05):
    """Set up the event-driven comparison and report how much work each method does.

    Builds a random ``n x n`` CSR connectivity with the given ``density`` and a
    random spike vector with the given ``firing_rate``. It then registers
    ``event_driven_v1`` and ``event_driven_v2`` as ``bti.XLACustomOp``
    operators and prints the problem statistics. Both methods perform the same
    synaptic updates (the nonzeros of the rows whose neuron fired). They differ
    in how many rows they visit: all ``n`` rows for method 1, only the fired
    rows for method 2.

    The registered operators are not executed here.

    Args:
        n: Number of neurons (rows and columns of the connectivity).
        density: Fraction of possible connections that exist.
        firing_rate: Probability that a given neuron fires.
    """
    from scipy.sparse import random as sparse_random

    # Create connectivity directly in CSR form; only its sparsity pattern
    # (indices/indptr) is used by the kernels. A dense n x n random matrix
    # would cost ~800 MB per float64 copy at n = 10000.
    sparse_conn = sparse_random(n, n, density=density, format="csr")

    # Create events
    events = np.random.rand(n) < firing_rate
    event_indices = np.where(events)[0].astype(np.int32)

    # Synaptic updates per step = nonzeros in the rows of neurons that fired.
    row_lengths = np.diff(sparse_conn.indptr)
    n_updates = int(row_lengths[events].sum())

    # Register operators (ready to be called in a simulation loop)
    op_v1 = bti.XLACustomOp(cpu_kernel=event_driven_v1, gpu_kernel=event_driven_v1)
    op_v2 = bti.XLACustomOp(cpu_kernel=event_driven_v2, gpu_kernel=event_driven_v2)

    print(f"Number of neurons: {n}")
    print(f"Connectivity density: {density}")
    print(f"Firing rate: {firing_rate}")
    print(f"Neurons that fired: {len(event_indices)}")
    print(f"Synaptic updates (both methods): {n_updates}")
    print("\nMethod 1: Direct event checking")
    print(f"  rows visited: {n}")
    print("Method 2: Event filtering (recommended for low firing rates)")
    print(f"  rows visited: {len(event_indices)}")


benchmark_event_driven()
Number of neurons: 10000
Connectivity density: 0.01
Firing rate: 0.05
Neurons that fired: 511
Method 1: Direct event checking
Method 2: Event filtering (recommended for low firing rates)
5. Data Type Optimization
Choosing the right data types can significantly impact performance and memory usage.
# Using different precision levels.
# Single precision: each element is cast to f32 before the math, so sqrt and
# sin are evaluated in float32.
@ti.kernel
def compute_f32(
    arr: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for k in range(arr.shape[0]):
        x32 = ti.cast(arr[k], ti.f32)
        out[k] = ti.sqrt(x32) + ti.sin(x32)
# Double precision: each element is cast to f64 before the math, so sqrt and
# sin are evaluated in float64 (the result is stored with `out`'s dtype).
@ti.kernel
def compute_f64(
    arr: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for k in range(arr.shape[0]):
        x64 = ti.cast(arr[k], ti.f64)
        out[k] = ti.sqrt(x64) + ti.sin(x64)
# Using integer types for indices: out[k] = data[indices[k]], with each index
# cast to i32 before it is used.
@ti.kernel
def gather_operation(
    data: ti.types.ndarray(ndim=1),
    indices: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for k in range(indices.shape[0]):
        src = ti.cast(indices[k], ti.i32)
        out[k] = data[src]
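6. Atomic Operations for Race Condition Handling
When multiple threads need to update the same memory location, use atomic operations. In Taichi, augmented assignments on global data such as `out[i] += v` are compiled to atomic operations automatically, whereas an explicit read-modify-write such as `out[i] = out[i] + v` is not atomic and can lose updates. `ti.atomic_add` makes the atomicity explicit (and returns the previous value).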
# Without atomic operations (has race conditions).
# Note: in Taichi, `out[idx] += values[i]` is already compiled to an atomic add
# and would NOT race. The explicit read-modify-write below is the genuinely
# unsafe form: two iterations with the same idx can both read the old value,
# and one of the two additions is then lost.
@ti.kernel
def scatter_add_unsafe(
    values: ti.types.ndarray(ndim=1),
    indices: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for i in range(values.shape[0]):
        idx = indices[i]
        out[idx] = out[idx] + values[i]  # non-atomic: races when indices repeat
# With atomic operations (safe): ti.atomic_add performs the read-add-write on
# out[indices[i]] as a single atomic step. Its return value (the previous
# contents) is not needed here.
@ti.kernel
def scatter_add_safe(
    values: ti.types.ndarray(ndim=1),
    indices: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for k in range(values.shape[0]):
        ti.atomic_add(out[indices[k]], values[k])
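7. Helper Functions for Code Reusability
Use `@ti.func` to create reusable helper functions; they can be called from kernels and from other `@ti.func` helpers, and Taichi inlines them at compile time.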
# Define helper functions
@ti.func
def relu(x: ti.f32) -> ti.f32:
    """Rectified linear unit: x for positive x, otherwise 0."""
    y = ti.max(0.0, x)
    return y
@ti.func
def sigmoid(x: ti.f32) -> ti.f32:
    """Logistic sigmoid 1 / (1 + exp(-x))."""
    denom = 1.0 + ti.exp(-x)
    return 1.0 / denom
@ti.func
def leaky_relu(x: ti.f32, alpha: ti.f32) -> ti.f32:
    """Leaky ReLU: x for positive x, otherwise alpha * x.

    Uses an explicit select rather than max(alpha * x, x): the max form is
    only correct for alpha <= 1 (for alpha > 1 it would return alpha * x for
    positive x). For 0 <= alpha <= 1 both forms give the same result.
    """
    return ti.select(x > 0.0, x, alpha * x)
# Use helper functions in kernels.
# activation_type selects the function applied elementwise:
# 0 -> relu, 1 -> sigmoid, anything else -> leaky_relu with alpha = 0.01.
@ti.kernel
def apply_activation(
    data: ti.types.ndarray(ndim=1),
    activation_type: ti.i32,
    out: ti.types.ndarray(ndim=1)
):
    for k in range(data.shape[0]):
        v = data[k]
        if activation_type == 1:
            out[k] = sigmoid(v)
        elif activation_type == 0:
            out[k] = relu(v)
        else:
            out[k] = leaky_relu(v, 0.01)
8. Best Practices Summary
Performance Checklist:
Loop Configuration
Use parallel loops for independent operations
Use serial loops when order matters or for break statements
Tune `block_dim` for GPU (try 128, 256, 512)
Memory Access
Prefer coalesced (contiguous) memory access
Cache frequently accessed values in local variables
Minimize global memory accesses
Data Types
Use `ti.f32` instead of `ti.f64` when precision allows
Use `ti.i32` for indices
Use explicit casting for clarity
Sparse Operations
Store data in efficient formats (CSR, COO)
Use event filtering for low firing rates
Cache row/column indices when possible
Synchronization
Use atomic operations for concurrent writes
Avoid unnecessary synchronization
Code Organization
Extract common operations into `@ti.func` helpers
Keep kernels focused and modular
Profile before optimizing
Profiling Tips:
# Manual wall-clock timing of an operator; replace `your_operator(...)` with a
# real call.
import time
import jax

# Warm up: the first calls include compilation.
for _ in range(3):
    result = your_operator(...)
jax.block_until_ready(result)

# Measure. JAX dispatches asynchronously, so wait for each result before the
# clock is read; otherwise only the launch overhead is measured.
n_iters = 100
start = time.perf_counter()
for _ in range(n_iters):
    result = jax.block_until_ready(your_operator(...))
elapsed = (time.perf_counter() - start) / n_iters
print(f"Average time: {elapsed*1000:.3f} ms")
9. Real-World Example: Optimized Spiking Neural Network Layer
Let’s combine all the optimization techniques into a complete, optimized SNN layer implementation.
# Helper functions for neuron dynamics
@ti.func
def lif_dynamics(v: ti.f32, current: ti.f32, tau: ti.f32, dt: ti.f32) -> ti.f32:
    """Leaky Integrate-and-Fire update: one forward-Euler step of tau * dv/dt = -v + current."""
    dv = ((current - v) / tau) * dt
    return v + dv
@ti.func
def check_spike(v: ti.f32, threshold: ti.f32) -> ti.i32:
    """Return 1 when the membrane potential has reached the threshold, else 0."""
    return ti.select(v >= threshold, 1, 0)
# Optimized SNN layer kernel: one simulation step for a layer of LIF neurons.
# For each postsynaptic neuron (one CSR row), it:
#   1. sums the weights of synapses whose presynaptic neuron spiked,
#   2. advances the membrane potential with lif_dynamics,
#   3. emits a spike when the potential reaches the threshold and resets the
#      potential to 0.0 after a spike.
# Results go to `output_spikes` (1.0 = spike, 0.0 = no spike) and
# `new_membrane_v`. Every output element is written with `=`, so the initial
# contents of the output buffers do not matter.
# Assumes syn_indptr has membrane_v.shape[0] + 1 entries (rows = postsynaptic
# neurons) and that syn_indices index into input_spikes.
@ti.kernel
def snn_layer_optimized(
    # Synaptic connectivity (CSR format)
    syn_values: ti.types.ndarray(ndim=1),
    syn_indices: ti.types.ndarray(ndim=1),
    syn_indptr: ti.types.ndarray(ndim=1),
    # Input spikes from previous layer
    # NOTE(review): treated as a float mask (> 0.5 means "spiked"); assumes 0.0/1.0 encoding
    input_spikes: ti.types.ndarray(ndim=1),
    # Neuron states
    membrane_v: ti.types.ndarray(ndim=1),
    # Parameters (as 0-dim arrays)
    tau: ti.types.ndarray(),
    threshold: ti.types.ndarray(),
    dt: ti.types.ndarray(),
    # Outputs
    output_spikes: ti.types.ndarray(ndim=1),
    new_membrane_v: ti.types.ndarray(ndim=1)
):
    n_neurons = membrane_v.shape[0]
    # 0-dim ndarrays are read with [None]; the values are read once, outside
    # the per-neuron loop.
    tau_val = tau[None]
    threshold_val = threshold[None]
    dt_val = dt[None]
    # Step 1: Compute synaptic currents (parallelized: this is the kernel's
    # top-level loop, one iteration per postsynaptic neuron)
    for post_neuron in range(n_neurons):
        # Accumulate synaptic input
        synaptic_current = 0.0
        start = syn_indptr[post_neuron]
        end = syn_indptr[post_neuron + 1]
        for j in range(start, end):
            pre_neuron = syn_indices[j]
            if input_spikes[pre_neuron] > 0.5: # Check if pre-synaptic neuron spiked
                synaptic_current += syn_values[j]
        # Step 2: Update membrane potential
        v_old = membrane_v[post_neuron]
        v_new = lif_dynamics(v_old, synaptic_current, tau_val, dt_val)
        # Step 3: Check for spike and reset
        spike = check_spike(v_new, threshold_val)
        if spike:
            v_new = 0.0 # Reset after spike
        # Write outputs (spike flag stored as 0.0/1.0 float)
        output_spikes[post_neuron] = ti.cast(spike, ti.f32)
        new_membrane_v[post_neuron] = v_new
# Register the operator: wrap the Taichi kernel as a braintaichi XLACustomOp,
# using the same kernel for the CPU and the GPU backend.
snn_layer_op = bti.XLACustomOp(
    cpu_kernel=snn_layer_optimized,
    gpu_kernel=snn_layer_optimized
)
# Test the optimized SNN layer
def test_snn_layer():
    """Run ``snn_layer_op`` once on random inputs and print summary statistics.

    Builds a random ``n_post x n_pre`` weight matrix (density 0.1, weights in
    [0, 0.5)), random input spikes (each active with probability 0.1) and random
    initial membrane potentials in [0, 0.5). It then prints the spike counts and
    membrane-potential statistics after one layer update.
    """
    n_pre, n_post, density = 1000, 800, 0.1

    # Random connectivity: rows are postsynaptic neurons, columns presynaptic.
    mask = (np.random.rand(n_post, n_pre) < density).astype(float)
    weights = mask * (np.random.rand(n_post, n_pre) * 0.5)
    csr = csr_matrix(weights)

    # Initial states
    input_spikes = (np.random.rand(n_pre) < 0.1).astype(np.float32)
    membrane_v = np.random.rand(n_post).astype(np.float32) * 0.5

    # Scalar parameters are passed as 0-dim float32 arrays: tau, threshold, dt.
    tau, threshold, dt = (jnp.array(p, dtype=jnp.float32) for p in (10.0, 1.0, 0.1))

    out_spec = jax.ShapeDtypeStruct((n_post,), dtype=jnp.float32)
    kernel_inputs = (
        jnp.array(csr.data, dtype=jnp.float32),
        jnp.array(csr.indices, dtype=jnp.int32),
        jnp.array(csr.indptr, dtype=jnp.int32),
        jnp.array(input_spikes, dtype=jnp.float32),
        jnp.array(membrane_v, dtype=jnp.float32),
        tau, threshold, dt,
    )
    output_spikes, new_v = snn_layer_op(*kernel_inputs, outs=[out_spec, out_spec])

    print(f"Input spikes: {input_spikes.sum()}/{n_pre}")
    print(f"Output spikes: {output_spikes.sum()}/{n_post}")
    print(f"Mean membrane potential: {new_v.mean():.4f}")
    print(f"Max membrane potential: {new_v.max():.4f}")


test_snn_layer()
Input spikes: 107.0/1000
Output spikes: 0.0/800
Mean membrane potential: 0.2725
Max membrane potential: 0.5309
Conclusion
This tutorial covered advanced optimization techniques for braintaichi:
Loop configuration and parallelization strategies
Memory access patterns and caching
Sparse operation optimization
Event-driven computation strategies
Data type selection
Atomic operations for thread safety
Code organization with helper functions
Complete optimized SNN layer example
For more information: