Quick Start Guide

This tutorial covers the basics of using braintaichi to build high-performance operators for brain dynamics.

Installation

First, make sure you have installed braintaichi:

pip install braintaichi

Or install from source:

git clone https://github.com/chaoming0625/braintaichi.git
cd braintaichi
pip install -e .

Import Required Libraries

Let’s start by importing the necessary libraries:

import sys
import numpy as np
import jax
import jax.numpy as jnp
import taichi as ti
from scipy.sparse import csr_matrix

import braintaichi as bti

Example 1: Simple Vector Addition

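Let’s start with a simple example: vector addition with a Taichi kernel. The registered operator returns a list with one array per entry in outs, so the result printed below is a one-element list.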

@ti.kernel
def vector_add(
    a: ti.types.ndarray(ndim=1),
    b: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    for i in range(a.shape[0]):
        out[i] = a[i] + b[i]

# Register the custom operator
vector_add_op = bti.XLACustomOp(
    cpu_kernel=vector_add,
    gpu_kernel=vector_add
)

# Test the operator
n = 10
a = jnp.arange(n, dtype=jnp.float32)
b = jnp.ones(n, dtype=jnp.float32)

result = vector_add_op(
    a, b,
    outs=[jax.ShapeDtypeStruct((n,), dtype=jnp.float32)]
)

print("Input a:", a)
print("Input b:", b)
print("Result:", result)

Input a: [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
Input b: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Result: [Array([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.], dtype=float32)]

Example 2: Sparse Matrix-Vector Multiplication

Brain networks are typically sparse. Let’s implement a sparse matrix-vector multiplication operator using CSR (Compressed Sparse Row) format.
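To make the CSR layout concrete, here is a minimal pure-Python sketch that converts a small dense matrix into the three arrays used by the kernel in this example. The helper name dense_to_csr is hypothetical and is not part of braintaichi or SciPy; it only illustrates what values, indices, and indptr hold.

```python
def dense_to_csr(dense):
    """Hypothetical helper: build CSR arrays from a list-of-lists matrix."""
    values, indices, indptr = [], [], [0]
    for row in dense:
        for col, x in enumerate(row):
            if x != 0:
                values.append(x)     # non-zero entries, stored row by row
                indices.append(col)  # column index of each stored entry
        # Row r occupies the slice indptr[r]:indptr[r + 1] of values/indices
        indptr.append(len(values))
    return values, indices, indptr


dense = [
    [5.0, 0.0, 0.0],
    [0.0, 0.0, 2.0],
    [1.0, 3.0, 0.0],
]
values, indices, indptr = dense_to_csr(dense)
print(values)   # [5.0, 2.0, 1.0, 3.0]
print(indices)  # [0, 2, 0, 1]
print(indptr)   # [0, 1, 2, 4]
```

Note that indptr has one more entry than the matrix has rows, which is why the kernel loops over range(indptr.shape[0] - 1).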

@ti.kernel
def csr_matvec(
    values: ti.types.ndarray(ndim=1),
    indices: ti.types.ndarray(ndim=1),
    indptr: ti.types.ndarray(ndim=1),
    vector: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    # Iterate over each row
    for row in range(indptr.shape[0] - 1):
        row_sum = 0.0
        # Iterate over non-zero elements in the row
        for j in range(indptr[row], indptr[row + 1]):
            col = indices[j]
            row_sum += values[j] * vector[col]
        out[row] = row_sum

# Register the operator
csr_matvec_op = bti.XLACustomOp(
    cpu_kernel=csr_matvec,
    gpu_kernel=csr_matvec
)

# Create a sparse matrix
n_rows, n_cols = 100, 100
density = 0.1
dense_matrix = (np.random.rand(n_rows, n_cols) < density).astype(float)
dense_matrix *= np.random.rand(n_rows, n_cols)

# Convert to CSR format
sparse_matrix = csr_matrix(dense_matrix)

# Create input vector
input_vector = np.random.rand(n_cols).astype(np.float32)

# Run the custom operator
result = csr_matvec_op(
    jnp.array(sparse_matrix.data, dtype=jnp.float32),
    jnp.array(sparse_matrix.indices, dtype=jnp.int32),
    jnp.array(sparse_matrix.indptr, dtype=jnp.int32),
    jnp.array(input_vector, dtype=jnp.float32),
    outs=[jax.ShapeDtypeStruct((n_rows,), dtype=jnp.float32)]
)

# Verify the result
expected = sparse_matrix @ input_vector
print("Custom operator result:", result[0][:5])
print("Expected result:", expected[:5])
print("Maximum difference:", np.max(np.abs(np.array(result[0]) - expected)))

Custom operator result: [1.8271513 2.4389558 1.7941285 2.318161  1.7279923]
Expected result: [1.82715139 2.43895573 1.79412849 2.31816102 1.72799216]
Maximum difference: 3.8372544519660323e-07
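The small non-zero difference is expected rather than a sign of a bug: the custom operator works in float32, while the SciPy reference is computed in float64 because dense_matrix was created with astype(float). The sketch below reproduces the effect in pure Python by rounding to float32 after every operation. The helper names are hypothetical, and the rounding is done with the standard struct module rather than with Taichi or JAX.

```python
import struct


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def dot_float32(xs, ys):
    """Dot product that rounds to float32 after every multiply and add."""
    total = 0.0
    for x, y in zip(xs, ys):
        product = to_float32(to_float32(x) * to_float32(y))
        total = to_float32(total + product)
    return total


xs = [0.1 * (i + 1) for i in range(10)]
ys = [0.3] * 10

exact = sum(x * y for x, y in zip(xs, ys))  # float64 reference
approx = dot_float32(xs, ys)                # float32 emulation
rel_err = abs(approx - exact) / abs(exact)
print(rel_err)  # small but non-zero
```

For this reason a tolerance-based comparison such as np.allclose(np.array(result[0]), expected, rtol=1e-5) is a more robust check than exact equality.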

Example 3: Event-Driven Computation

Brain dynamics are often event-driven. Let’s implement an event-driven synaptic transmission operator.

@ti.kernel
def event_csr_matvec(
    values: ti.types.ndarray(ndim=1),
    indices: ti.types.ndarray(ndim=1),
    indptr: ti.types.ndarray(ndim=1),
    events: ti.types.ndarray(ndim=1),  # Boolean array indicating which neurons fired
    out: ti.types.ndarray(ndim=1)
):
    # Run the row loop serially: several rows may accumulate into the same out[col]
    ti.loop_config(serialize=True)
    for row in range(indptr.shape[0] - 1):
        if events[row]:  # Skip rows whose neuron did not fire
            for j in range(indptr[row], indptr[row + 1]):
                col = indices[j]
                # Accumulates in place, so the result is only correct if out starts at zero
                out[col] += values[j]

# Register the operator
event_csr_op = bti.XLACustomOp(
    cpu_kernel=event_csr_matvec,
    gpu_kernel=event_csr_matvec
)

# Create test data
n_neurons = 1000
density = 0.1

# Create sparse connectivity matrix
connectivity = (np.random.rand(n_neurons, n_neurons) < density).astype(float)
connectivity *= np.random.rand(n_neurons, n_neurons) * 0.5  # Synaptic weights
sparse_conn = csr_matrix(connectivity)

# Generate random spike events (each neuron fires with probability 0.1)
events = np.random.rand(n_neurons) < 0.1

# Run the event-driven operator
result = event_csr_op(
    jnp.array(sparse_conn.data, dtype=jnp.float32),
    jnp.array(sparse_conn.indices, dtype=jnp.int32),
    jnp.array(sparse_conn.indptr, dtype=jnp.int32),
    jnp.array(events, dtype=jnp.bool_),
    outs=[jax.ShapeDtypeStruct((n_neurons,), dtype=jnp.float32)]
)

print(f"Number of neurons that fired: {events.sum()}")
print("Synaptic input statistics:")
print(f"  Mean: {np.mean(result[0]):.4f}")
print(f"  Max: {np.max(result[0]):.4f}")
print(f"  Non-zero entries: {np.sum(np.array(result[0]) > 0)}")

Number of neurons that fired: 98
Synaptic input statistics:
  Mean: 2.4380
  Max: 5.0876
  Non-zero entries: 1000
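One way to sanity-check an event-driven kernel like the one above is to compare it against a plain reference on a tiny network. The sketch below is pure Python with hypothetical helper names. It assumes the same convention as the kernel (rows index the neurons that fire, columns index the targets that receive input) and checks that the scatter over fired rows matches a dense sum restricted to those rows.

```python
def event_scatter(values, indices, indptr, events, n_out):
    """Reference for the event-driven update: only fired rows are visited."""
    out = [0.0] * n_out  # start from zero because the update accumulates
    for row in range(len(indptr) - 1):
        if events[row]:
            for j in range(indptr[row], indptr[row + 1]):
                out[indices[j]] += values[j]
    return out


def dense_reference(dense, events):
    """Dense equivalent: out[col] sums dense[row][col] over the fired rows."""
    n_out = len(dense[0])
    return [
        sum(dense[row][col] for row in range(len(dense)) if events[row])
        for col in range(n_out)
    ]


# Three neurons; the same weights in dense and CSR form
dense = [
    [0.0, 0.5, 0.25],
    [0.5, 0.0, 0.0],
    [0.0, 1.0, 0.0],
]
values = [0.5, 0.25, 0.5, 1.0]
indices = [1, 2, 0, 1]
indptr = [0, 2, 3, 4]
events = [True, False, True]  # neurons 0 and 2 fired

out = event_scatter(values, indices, indptr, events, n_out=3)
print(out)  # [0.0, 1.5, 0.25]
```

Only the rows of fired neurons are touched, so the work scales with the number of events rather than with the total number of stored connections.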

Example 4: Using Built-in Operators

braintaichi provides many pre-implemented operators for common brain dynamics operations. Let’s explore some of them.

# List the public names exported by braintaichi
print("Available modules in braintaichi:")
print([attr for attr in dir(bti) if not attr.startswith('_')])

Available modules in braintaichi:
['XLACustomOp', 'coo_to_csr', 'coomv', 'cpu_ops', 'csr_to_coo', 'csr_to_dense', 'csrmm', 'csrmv', 'defjvp', 'event_csrmm', 'event_csrmv', 'get_homo_weight_matrix', 'get_normal_weight_matrix', 'get_uniform_weight_matrix', 'jitc_event_mv_prob_homo', 'jitc_event_mv_prob_normal', 'jitc_event_mv_prob_uniform', 'jitc_mv_prob_homo', 'jitc_mv_prob_normal', 'jitc_mv_prob_uniform', 'rand', 'register_general_batching']
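Several of the names above, such as csr_to_coo and coo_to_csr, refer to conversions between sparse formats. The signatures of the braintaichi functions are not shown in this guide, so the sketch below does not call them. It is a pure-Python illustration, with hypothetical helper names, of what converting between CSR and COO (coordinate) form involves.

```python
def csr_to_coo_sketch(indices, indptr):
    """Expand CSR row pointers into one explicit row index per stored entry."""
    rows = []
    for row in range(len(indptr) - 1):
        rows.extend([row] * (indptr[row + 1] - indptr[row]))
    return rows, list(indices)


def coo_to_csr_sketch(rows, cols, n_rows):
    """Rebuild CSR arrays from COO pairs that are already sorted by row."""
    counts = [0] * n_rows
    for r in rows:
        counts[r] += 1
    indptr = [0]
    for c in counts:
        indptr.append(indptr[-1] + c)  # cumulative count of entries per row
    return list(cols), indptr


indices = [1, 2, 0, 1]
indptr = [0, 2, 3, 4]

rows, cols = csr_to_coo_sketch(indices, indptr)
print(rows, cols)  # [0, 0, 1, 2] [1, 2, 0, 1]

roundtrip = coo_to_csr_sketch(rows, cols, n_rows=3)
print(roundtrip)  # ([1, 2, 0, 1], [0, 2, 3, 4])
```

The stored values are unchanged by either conversion, which is why the sketch only handles the index arrays.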

Performance Tips

Here are some key tips for optimizing your custom operators:

  1. Parallelize outer loops: Taichi automatically parallelizes the outermost for-loops of a kernel, so place the largest independent loop at the top level.

  2. Use ti.loop_config(serialize=True): Put it immediately before a loop that must run sequentially or that uses break statements.

  3. Choose appropriate data types: Use ti.f32 for single precision and ti.f64 for double precision.

  4. Avoid Python objects inside kernels: Use Taichi native types only.

  5. Batch operations: Process multiple operations together to reduce per-call overhead.
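Tips 1 and 2 interact with how a kernel writes its output. In the CSR kernel of Example 2 each row writes only out[row], so rows are independent. In the event-driven kernel of Example 3 several rows can add to the same out[col]. The pure-Python sketch below, with a hypothetical helper name, lists such shared targets for a CSR pattern. It is only a way to reason about which loops are independent, not a description of how Taichi schedules work.

```python
from collections import defaultdict


def shared_scatter_targets(indices, indptr):
    """Return the columns that more than one row writes to in a scatter update."""
    writers = defaultdict(set)
    for row in range(len(indptr) - 1):
        for j in range(indptr[row], indptr[row + 1]):
            writers[indices[j]].add(row)
    return sorted(col for col, rows in writers.items() if len(rows) > 1)


# Row 0 targets columns 1 and 2, row 1 targets column 0, row 2 targets column 1
indices = [1, 2, 0, 1]
indptr = [0, 2, 3, 4]

conflicts = shared_scatter_targets(indices, indptr)
print(conflicts)  # [1]
```

A non-empty result means that rows are not independent in a scatter-style update, so a row-parallel version would have to rely on atomic adds, or the loop can be serialized as Example 3 does.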

Integration with JAX

Operators registered with braintaichi integrate with JAX, so they can be used with JIT compilation and automatic differentiation.

# Example: Using braintaichi operator in a JAX JIT-compiled function
@jax.jit
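def neural_network_step(weights_data, weights_indices, weights_indptr, inputs):
    """Compute the synaptic input for one network step (a sparse matrix-vector product)."""
    # A CSR matrix has len(indptr) - 1 rows, which is also the output length
    n_rows = weights_indptr.shape[0] - 1
    # Apply synaptic weights; the operator returns a list with one array per entry in outs
    synaptic_input = csr_matvec_op(
        weights_data,
        weights_indices,
        weights_indptr,
        inputs,
        outs=[jax.ShapeDtypeStruct((n_rows,), dtype=jnp.float32)]
    )
    return synaptic_input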

# Test the JIT-compiled function
n = 100
sparse_mat = csr_matrix((np.random.rand(n, n) < 0.1).astype(float) * np.random.rand(n, n))
inputs = jnp.array(np.random.rand(n), dtype=jnp.float32)

result = neural_network_step(
    jnp.array(sparse_mat.data, dtype=jnp.float32),
    jnp.array(sparse_mat.indices, dtype=jnp.int32),
    jnp.array(sparse_mat.indptr, dtype=jnp.int32),
    inputs
)

print("Neural network output:", result[0][:5])

Neural network output: [4.365637  1.9171182 1.9304688 2.2530348 2.110351 ]

Next Steps

Now that you’ve learned the basics, you can:

  1. Read the braintaichi_intro.ipynb for detailed kernel registration interfaces

  2. Explore the complete_example.ipynb for more complex use cases

  3. Check out the advanced_optimization.ipynb for performance optimization techniques

  4. Visit the API Documentation for detailed reference

For more examples, check the source code: