Quick Start Guide
This tutorial guides you through the basics of using braintaichi to create high-performance brain dynamics operators.
Installation
First, make sure braintaichi is installed:
pip install braintaichi
Or install from source:
git clone https://github.com/chaoming0625/braintaichi.git
cd braintaichi
pip install -e .
Import Required Libraries
Let’s start by importing the necessary libraries:
import sys
import numpy as np
import jax
import jax.numpy as jnp
import taichi as ti
from scipy.sparse import csr_matrix
import braintaichi as bti
Example 1: Simple Vector Addition
Let’s start with a simple example: vector addition using a Taichi kernel.
@ti.kernel
def vector_add(
    a: ti.types.ndarray(ndim=1),
    b: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1),
):
    # Element-wise sum: out[k] = a[k] + b[k].
    # The loop is bounded by len(a); b and out are assumed to be at least
    # that long (TODO confirm at call sites).
    # This range-for is the kernel's outermost loop, so Taichi parallelizes it.
    for k in range(a.shape[0]):
        out[k] = a[k] + b[k]
# Register the custom operator: the same Taichi kernel backs both CPU and GPU.
vector_add_op = bti.XLACustomOp(cpu_kernel=vector_add, gpu_kernel=vector_add)

# Test the operator
n = 10
a = jnp.arange(n, dtype=jnp.float32)
b = jnp.ones(n, dtype=jnp.float32)
# Shape/dtype of the single output buffer the kernel writes into.
out_spec = jax.ShapeDtypeStruct((n,), dtype=jnp.float32)
# The operator returns a list with one array per entry of `outs`.
result = vector_add_op(a, b, outs=[out_spec])

for label, value in (("Input a:", a), ("Input b:", b), ("Result:", result)):
    print(label, value)
Input a: [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
Input b: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Result: [Array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], dtype=float32)]
Example 2: Sparse Matrix-Vector Multiplication
Brain networks are typically sparse, so let’s implement a sparse matrix-vector multiplication operator using the CSR (Compressed Sparse Row) format.
@ti.kernel
def csr_matvec(
    values: ti.types.ndarray(ndim=1),
    indices: ti.types.ndarray(ndim=1),
    indptr: ti.types.ndarray(ndim=1),
    vector: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1),
):
    # CSR matrix-vector product: out = A @ vector.
    # Row i's non-zeros are values[indptr[i]:indptr[i + 1]], located at the
    # columns indices[indptr[i]:indptr[i + 1]].
    # Each outer iteration writes only its own out[i], so the outer
    # (parallel) loop needs no atomic updates.
    for i in range(indptr.shape[0] - 1):
        acc = 0.0
        for k in range(indptr[i], indptr[i + 1]):
            acc += values[k] * vector[indices[k]]
        out[i] = acc
# Register the operator (same kernel for CPU and GPU)
csr_matvec_op = bti.XLACustomOp(cpu_kernel=csr_matvec, gpu_kernel=csr_matvec)

# Create a random sparse matrix: each entry is kept with probability
# `density` and then scaled by a uniform random weight.
n_rows, n_cols = 100, 100
density = 0.1
keep = np.random.rand(n_rows, n_cols) < density
dense_matrix = keep.astype(float)
dense_matrix *= np.random.rand(n_rows, n_cols)

# Convert to CSR format (exposes .data / .indices / .indptr)
sparse_matrix = csr_matrix(dense_matrix)

# Create input vector
input_vector = np.random.rand(n_cols).astype(np.float32)

# Run the custom operator; CSR values/vector go in as float32, index arrays as int32.
csr_operands = (
    jnp.array(sparse_matrix.data, dtype=jnp.float32),
    jnp.array(sparse_matrix.indices, dtype=jnp.int32),
    jnp.array(sparse_matrix.indptr, dtype=jnp.int32),
    jnp.array(input_vector, dtype=jnp.float32),
)
result = csr_matvec_op(
    *csr_operands,
    outs=[jax.ShapeDtypeStruct((n_rows,), dtype=jnp.float32)],
)

# Verify the result against SciPy's reference product
expected = sparse_matrix @ input_vector
custom = np.array(result[0])
print("Custom operator result:", result[0][:5])
print("Expected result:", expected[:5])
print("Maximum difference:", np.max(np.abs(custom - expected)))
Custom operator result: [1.8271513 2.4389558 1.7941285 2.318161 1.7279923]
Expected result: [1.82715139 2.43895573 1.79412849 2.31816102 1.72799216]
Maximum difference: 3.8372544519660323e-07
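Example 3: Event-Driven Computation
Brain dynamics are often event-driven. Let’s implement an event-driven synaptic transmission operator: only the rows of neurons that fired are visited, and each fired neuron adds its outgoing synaptic weights to the inputs of its target neurons (for a 0/1 spike vector `events`, the result is `events @ W`).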
@ti.kernel
def event_csr_matvec(
    values: ti.types.ndarray(ndim=1),
    indices: ti.types.ndarray(ndim=1),
    indptr: ti.types.ndarray(ndim=1),
    events: ti.types.ndarray(ndim=1),  # Boolean array indicating which neurons fired
    out: ti.types.ndarray(ndim=1),
):
    # Event-driven transposed CSR product: for every presynaptic neuron `row`
    # that fired, add its outgoing weights values[j] to the postsynaptic
    # targets indices[j]. For a 0/1 spike vector this is out = events @ W.
    #
    # `out` is accumulated with `+=` below, so clear it explicitly instead of
    # relying on whatever the output buffer happens to contain.
    for i in range(out.shape[0]):
        out[i] = 0.0
    # Serialize the scatter loop: different fired rows can add into the same
    # out[col]. loop_config applies only to the for-loop that follows it.
    ti.loop_config(serialize=True)
    for row in range(indptr.shape[0] - 1):
        if events[row]:  # Only process rows of neurons that fired
            for j in range(indptr[row], indptr[row + 1]):
                col = indices[j]
                out[col] += values[j]
# Register the operator
event_csr_op = bti.XLACustomOp(
    cpu_kernel=event_csr_matvec,
    gpu_kernel=event_csr_matvec,
)

# Create test data
n_neurons = 1000
density = 0.1
# Create sparse connectivity matrix: each pair is connected with probability
# `density` and gets a uniform synaptic weight in [0, 0.5).
connectivity = (np.random.rand(n_neurons, n_neurons) < density).astype(float)
connectivity *= np.random.rand(n_neurons, n_neurons) * 0.5  # Synaptic weights
sparse_conn = csr_matrix(connectivity)

# Generate random spike events: each neuron fires with probability 0.1,
# so roughly 10% of neurons fire.
events = np.random.rand(n_neurons) < 0.1

# Run the event-driven operator; it returns a list with one output array.
result = event_csr_op(
    jnp.array(sparse_conn.data, dtype=jnp.float32),
    jnp.array(sparse_conn.indices, dtype=jnp.int32),
    jnp.array(sparse_conn.indptr, dtype=jnp.int32),
    jnp.array(events, dtype=jnp.bool_),
    outs=[jax.ShapeDtypeStruct((n_neurons,), dtype=jnp.float32)],
)

# Copy the output to a NumPy array once and reuse it for every statistic.
synaptic_input = np.asarray(result[0])
print(f"Number of neurons that fired: {events.sum()}")
print("Synaptic input statistics:")
print(f" Mean: {np.mean(synaptic_input):.4f}")
print(f" Max: {np.max(synaptic_input):.4f}")
print(f" Non-zero entries: {np.sum(synaptic_input > 0)}")
Number of neurons that fired: 98
Synaptic input statistics:
Mean: 2.4380
Max: 5.0876
Non-zero entries: 1000
Example 4: Using Built-in Operators
braintaichi provides many pre-implemented operators for common brain dynamics operations. Let’s list the public names it exports.
# Check available operators in braintaichi: keep every public (non-underscore) name.
public_api = [name for name in dir(bti) if not name.startswith("_")]
print("Available modules in braintaichi:")
print(public_api)
Available modules in braintaichi:
['XLACustomOp', 'coo_to_csr', 'coomv', 'cpu_ops', 'csr_to_coo', 'csr_to_dense', 'csrmm', 'csrmv', 'defjvp', 'event_csrmm', 'event_csrmv', 'get_homo_weight_matrix', 'get_normal_weight_matrix', 'get_uniform_weight_matrix', 'jitc_event_mv_prob_homo', 'jitc_event_mv_prob_normal', 'jitc_event_mv_prob_uniform', 'jitc_mv_prob_homo', 'jitc_mv_prob_normal', 'jitc_mv_prob_uniform', 'rand', 'register_general_batching']
Performance Tips
Here are some key tips for optimizing your custom operators:
- Parallelize outer loops: Taichi automatically parallelizes the outermost for-loops of a kernel.
- Use `ti.loop_config(serialize=True)` when you need sequential execution or want to use `break` statements.
- Choose appropriate data types: use `ti.f32` for single precision and `ti.f64` for double precision.
- Avoid Python objects inside kernels: use Taichi native types only.
- Batch operations: process multiple operations together to reduce overhead.
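Integration with JAX
One of the powerful features of braintaichi is its seamless integration with JAX: custom operators can be called inside `jax.jit`-compiled functions, as shown below. Automatic differentiation through a custom operator additionally requires derivative rules to be registered for it.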
# Example: Using braintaichi operator in a JAX JIT-compiled function
@jax.jit
def neural_network_step(weights_data, weights_indices, weights_indptr, inputs):
    """Compute the synaptic input ``W @ inputs`` for a CSR weight matrix ``W``.

    A minimal example of calling a braintaichi custom operator inside a
    ``jax.jit``-compiled function.

    Args:
        weights_data: Non-zero values of ``W`` (CSR ``data``); passed as
            float32 in this tutorial.
        weights_indices: Column index of each non-zero (CSR ``indices``);
            passed as int32 in this tutorial.
        weights_indptr: Row pointers of ``W`` (CSR ``indptr``); its length is
            the number of rows plus one.
        inputs: Dense input vector with one entry per column of ``W``.

    Returns:
        The operator's output list; its single element is a float32 vector
        with one entry per row of ``W``.
    """
    # The kernel writes out[row] for every row of W, so size the output by
    # the row count (len(indptr) - 1), not by len(inputs). The two differ for
    # non-square matrices. Shapes are static under jit, so this is a Python int.
    n_rows = weights_indptr.shape[0] - 1
    # Apply synaptic weights
    synaptic_input = csr_matvec_op(
        weights_data,
        weights_indices,
        weights_indptr,
        inputs,
        outs=[jax.ShapeDtypeStruct((n_rows,), dtype=jnp.float32)],
    )
    return synaptic_input
# Test the JIT-compiled function on a random 100x100 matrix with ~10% non-zeros.
n = 100
mask = np.random.rand(n, n) < 0.1
sparse_mat = csr_matrix(mask.astype(float) * np.random.rand(n, n))
inputs = jnp.array(np.random.rand(n), dtype=jnp.float32)

# CSR components as float32 values plus int32 index arrays.
csr_args = (
    jnp.array(sparse_mat.data, dtype=jnp.float32),
    jnp.array(sparse_mat.indices, dtype=jnp.int32),
    jnp.array(sparse_mat.indptr, dtype=jnp.int32),
)
result = neural_network_step(*csr_args, inputs)
print("Neural network output:", result[0][:5])
Neural network output: [4.365637 1.9171182 1.9304688 2.2530348 2.110351 ]
Next Steps
Now that you’ve learned the basics, you can:
Read braintaichi_intro.ipynb for detailed kernel registration interfaces
Explore complete_example.ipynb for more complex use cases
Check out advanced_optimization.ipynb for performance optimization techniques
Visit the API Documentation for detailed reference
For more examples, check the source code: