BrainTaichi Introduction

This tutorial provides a comprehensive guide on how to develop custom operators using BrainTaichi.

Kernel Registration Interface

Brain dynamics computations are sparse and event-driven; however, the specialized operators they need are not well abstracted or standardized. As a result, we often need to write custom operators. In this tutorial, we will explore how to build custom brain dynamics operators using BrainTaichi.

Start by importing the relevant Python packages.

import jax
import jax.numpy as jnp
import numpy as np
import taichi as ti
import braintaichi as bti

# Initialize Taichi
ti.init(arch=ti.cpu)
print("Taichi initialized successfully")
[Taichi] version 1.7.3, llvm 15.0.1, commit 5ec301be, win, python 3.12.12
[Taichi] Starting on arch=x64
Taichi initialized successfully

Basic Structure of Custom Operators

Taichi uses Python functions and decorators to define custom operators. Here is a basic structure of a custom operator:

@ti.kernel
def my_kernel(arg1: ti.types.ndarray(ndim=1), arg2: ti.types.ndarray(ndim=1)):
    # Internal logic of the operator
    for i in range(arg1.shape[0]):
        arg2[i] = arg1[i] * 2.0

The @ti.kernel decorator marks the function as a Taichi kernel: Taichi compiles it for the backend selected in ti.init (here, the CPU), and the outermost for loop is parallelized automatically.

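Defining Helper Functions

When defining complex custom operators, you can use the @ti.func decorator to define helper functions. These helpers are inlined into the calling kernel at compile time and can be called only from Taichi scope, that is, from inside a kernel or another @ti.func function: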

@ti.func
def helper_func(x: ti.f32) -> ti.f32:
    # Auxiliary computation
    return x * 2


@ti.kernel
def my_kernel_with_helper(arg: ti.types.ndarray(ndim=1)):
    for i in range(arg.shape[0]):
        arg[i] = helper_func(arg[i])

Example 1: Using Built-in BrainTaichi Operators

Before creating custom operators, let’s see how to use BrainTaichi’s built-in operators. BrainTaichi provides optimized sparse matrix operations for brain dynamics simulations:

# Example: Using event-driven CSR matrix-vector multiplication
# This is useful for spike propagation in neural networks

# Create a small sparse connectivity matrix in CSR format
# 5 neurons, sparse connections
num_pre = 5
num_post = 5

# CSR format: indptr, indices, data
# Row 0: connects to [1, 3]
# Row 1: connects to [0, 2, 4]  
# Row 2: connects to [1]
# Row 3: connects to [2, 4]
# Row 4: connects to [0, 3]

indices = jnp.array([1, 3, 0, 2, 4, 1, 2, 4, 0, 3], dtype=jnp.int32)
indptr = jnp.array([0, 2, 5, 6, 8, 10], dtype=jnp.int32)
weight = jnp.array([0.5], dtype=jnp.float32)  # Homogeneous weight

# Events: which pre-synaptic neurons fired
events = jnp.array([1.0, 0.0, 1.0, 0.0, 1.0], dtype=jnp.float32)

print("Sparse connectivity (CSR format):")
print(f"  indices: {indices}")
print(f"  indptr: {indptr}")
print(f"  weight: {weight[0]}")
print(f"\nPre-synaptic spikes: {events}")
print(f"  (Neurons 0, 2, 4 fired)")

# Use BrainTaichi's event-driven CSR matrix-vector multiply.
# transpose=True treats `events` as the pre-synaptic (row) vector, so each
# firing row adds `weight` to every post-synaptic column it connects to.
result = bti.event_csrmv(
    weight,
    indices,
    indptr,
    events,
    shape=(num_pre, num_post),
    transpose=True
)

print(f"\nPost-synaptic currents: {result}")
print("Explanation: Each firing neuron sends weight to its targets")
Sparse connectivity (CSR format):
  indices: [1 3 0 2 4 1 2 4 0 3]
  indptr: [ 0  2  5  6  8 10]
  weight: 0.5

Pre-synaptic spikes: [1. 0. 1. 0. 1.]
  (Neurons 0, 2, 4 fired)

Post-synaptic currents: [0.5 1.  0.  1.  0. ]
Explanation: Each firing neuron sends weight to its targets
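To check these numbers by hand, here is a plain NumPy sketch of the propagation described above, in which each firing row adds weight * event to every column it connects to. `event_csr_reference` is a hypothetical helper for testing, not BrainTaichi's implementation.

```python
import numpy as np


def event_csr_reference(weight, indices, indptr, events, num_post):
    """Scatter weight * event from each firing row to its target columns."""
    out = np.zeros(num_post, dtype=np.float32)
    for i in range(len(indptr) - 1):
        if events[i] != 0:
            targets = indices[indptr[i]:indptr[i + 1]]
            np.add.at(out, targets, weight * events[i])  # handles repeated targets
    return out


indices = np.array([1, 3, 0, 2, 4, 1, 2, 4, 0, 3], dtype=np.int32)
indptr = np.array([0, 2, 5, 6, 8, 10], dtype=np.int32)
events = np.array([1.0, 0.0, 1.0, 0.0, 1.0], dtype=np.float32)

out = event_csr_reference(0.5, indices, indptr, events, num_post=5)

# Cross-check against a dense matrix whose row i holds neuron i's outgoing weights
dense = np.zeros((5, 5), dtype=np.float32)
for i in range(5):
    dense[i, indices[indptr[i]:indptr[i + 1]]] = 0.5
assert np.allclose(out, events @ dense)
print(out)  # [0.5 1.  0.  1.  0. ]
```

Only rows with a nonzero event are visited, which is the saving an event-driven kernel exploits when few neurons fire.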

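Example 2: Custom Event Processing Operator

Now let’s create a simple custom operator: an event-driven sparse operation using the ELL (ELLPACK) format. ELL stores the column indices of each row in a dense 2D array whose width is the maximum number of connections per row; shorter rows are padded with a sentinel value, which is -1 in this example.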
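Since the connectivity in Example 1 is already in CSR form, a minimal sketch of converting it to the padded ELL layout used below may help. `csr_to_ell` is a hypothetical helper written for illustration, not a BrainTaichi function.

```python
import numpy as np


def csr_to_ell(indices, indptr, pad=-1):
    """Convert CSR column indices to a padded (num_rows, max_width) ELL array."""
    row_lengths = np.diff(indptr)           # connections per row
    width = int(row_lengths.max())          # ELL width = longest row
    ell = np.full((len(row_lengths), width), pad, dtype=np.int32)
    for i, n in enumerate(row_lengths):
        ell[i, :n] = indices[indptr[i]:indptr[i] + n]
    return ell


indices = np.array([1, 3, 0, 2, 4, 1, 2, 4, 0, 3], dtype=np.int32)
indptr = np.array([0, 2, 5, 6, 8, 10], dtype=np.int32)
ell = csr_to_ell(indices, indptr)
print(ell)
```

The fixed width gives every row the same memory layout, which suits parallel hardware; the cost is wasted padding when row lengths vary widely.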

# Step 1: Define the Taichi kernels for CPU and GPU

@ti.kernel
def event_ell_cpu(
    indices: ti.types.ndarray(ndim=2),
    events: ti.types.ndarray(ndim=1),
    weight: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    """
    Event-driven sparse matrix-vector multiply using ELL format.
    Only processes rows where events[i] > 0.
    """
    w = weight[0]
    num_rows, num_cols = indices.shape
    
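    # Run this loop serially on CPU. (Taichi makes `+=` on ndarray elements
    # atomic, so the parallel GPU version below is also race-free.)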
    ti.loop_config(serialize=True)
    for i in range(num_rows):
        if events[i] > 0.0:
            # Process all connections for this row
            for j in range(num_cols):
                col_idx = indices[i, j]
                if col_idx >= 0:  # Check for valid index
                    out[col_idx] += w * events[i]


@ti.kernel  
def event_ell_gpu(
    indices: ti.types.ndarray(ndim=2),
    events: ti.types.ndarray(ndim=1),
    weight: ti.types.ndarray(ndim=1),
    out: ti.types.ndarray(ndim=1)
):
    """
    GPU version: parallel across rows
    """
    w = weight[0]
    num_rows, num_cols = indices.shape
    
    for i in range(num_rows):
        if events[i] > 0.0:
            for j in range(num_cols):
                col_idx = indices[i, j]
                if col_idx >= 0:
                    out[col_idx] += w * events[i]

print("Kernels defined successfully")
Kernels defined successfully
# Step 2: Register the operator with BrainTaichi
# This wraps the kernels as a JAX primitive that can be called inside jax.jit;
# jax.grad and jax.vmap additionally rely on differentiation/batching rules.

event_ell_op = bti.XLACustomOp(
    cpu_kernel=event_ell_cpu,
    gpu_kernel=event_ell_gpu,
    name='event_ell_custom'
)

print("Operator registered successfully")

# Step 3: Create test data in ELL format
num_neurons = 5
max_connections = 3

# ELL format: each row has fixed width (max_connections)
# -1 indicates no connection
ell_indices = [
    [1, 3, -1],    # Neuron 0 connects to [1, 3]
    [0, 2, 4],     # Neuron 1 connects to [0, 2, 4]
    [1, -1, -1],   # Neuron 2 connects to [1]
    [2, 4, -1],    # Neuron 3 connects to [2, 4]
    [0, 3, -1],    # Neuron 4 connects to [0, 3]
]

# Same events as before
events = [1.0, 0.0, 1.0, 0.0, 1.0]  # Neurons 0, 2, 4 fire
weight = [0.5]

print(f"\nELL format connectivity:")
print(f"  indices shape: {len(ell_indices)}x{len(ell_indices[0])}")
print(f"  events: {events}")
print(f"  weight: {weight[0]}")
Operator registered successfully

ELL format connectivity:
  indices shape: 5x3
  events: [1.0, 0.0, 1.0, 0.0, 1.0]
  weight: 0.5

Test the custom operator

## JAX Integration Features

# One of the key benefits of BrainTaichi's XLACustomOp is integration
# with JAX. Custom operators can be used with:
# - JIT compilation (jax.jit): compile once, reuse on later calls
# - Automatic differentiation (jax.grad): requires differentiation rules
#   registered for the operator
# - Vectorization (jax.vmap): relies on a batching rule for the operator

# Example: JIT compilation for faster execution
@jax.jit
def run_network_jit(indices, events, weight):
    """JIT-compiled function using our custom operator"""
    return event_ell_op(
        indices,
        events, 
        weight,
        outs=[jax.ShapeDtypeStruct((num_neurons,), dtype=jnp.float32)]
    )[0]

# First call compiles the function
result_jit = run_network_jit(
    jnp.array(ell_indices, dtype=jnp.int32),
    jnp.array(events, dtype=jnp.float32),
    jnp.array(weight, dtype=jnp.float32)
)

print("JIT-compiled result:", result_jit)
print("\nThe JIT-compiled version will be faster on subsequent calls!")
JIT-compiled result: [0.5 1.  0.  1.  0. ]

The JIT-compiled version will be faster on subsequent calls!
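Because the registered operator is opaque to JAX, a pure-JAX reference of the same computation is handy for testing its output. The sketch below (`event_ell_reference` is a hypothetical name, not part of BrainTaichi) implements the ELL scatter with `jnp.where` and `.at[].add`; being ordinary JAX code, it is also differentiable without any custom rules.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="num_post")
def event_ell_reference(indices, events, weight, num_post):
    """Pure-JAX version of the ELL event-driven scatter (for testing)."""
    valid = indices >= 0                      # mask out -1 padding
    fired = events[:, None] > 0.0             # rows whose neuron fired
    contrib = jnp.where(valid & fired, weight[0] * events[:, None], 0.0)
    safe_idx = jnp.where(valid, indices, 0)   # padding -> index 0, adds 0.0
    out = jnp.zeros(num_post, dtype=contrib.dtype)
    return out.at[safe_idx.ravel()].add(contrib.ravel())  # accumulates duplicates


ell = jnp.array([[1, 3, -1], [0, 2, 4], [1, -1, -1], [2, 4, -1], [0, 3, -1]], dtype=jnp.int32)
events = jnp.array([1.0, 0.0, 1.0, 0.0, 1.0], dtype=jnp.float32)
weight = jnp.array([0.5], dtype=jnp.float32)

out = event_ell_reference(ell, events, weight, num_post=5)
print(out)  # [0.5 1.  0.  1.  0. ]

# d(sum(out))/d(weight) counts the active connections of firing neurons (2 + 1 + 2)
grad_w = jax.grad(lambda w: event_ell_reference(ell, events, w, num_post=5).sum())(weight)
print(grad_w)  # [5.]
```

Comparing the custom operator's result against this reference with `jnp.allclose` is a simple regression test.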

Basic Taichi Concepts

Taichi is a domain-specific language (DSL) embedded in Python, designed for high-performance computing. Here are key concepts:

1. Embedded in Python

Taichi is embedded within Python, letting you leverage Python’s simplicity while getting native GPU/CPU performance. If you know Python, you can start using Taichi immediately.

2. Just-in-Time (JIT) Compilation

Taichi JIT-compiles kernels through backends such as LLVM and SPIR-V into native CPU or GPU instructions. A kernel is compiled when it is first called, so subsequent calls run the compiled code directly.

3. Imperative Programming

Unlike many DSLs, Taichi uses an imperative programming paradigm, giving you flexibility to write complex computations in a single “mega-kernel”.

4. Compiler Optimizations

Taichi employs various optimizations (common subexpression elimination, dead code elimination, control flow analysis) that are backend-neutral thanks to its own IR layer.

Kernel Optimization in Taichi

Taichi kernels automatically parallelize for loops in the outermost scope, and the compiler chooses launch settings suited to the target architecture. For the last few percent of performance, Taichi also provides APIs to fine-tune these settings; on GPU, specifying a proper block_dim is key.

You can use ti.loop_config to set the loop directives for the next for loop. Available directives are:

  • parallelize: Sets the number of threads to use on CPU

  • block_dim: Sets the number of threads in a block on GPU

  • serialize: If set to True, the for loop runs serially, so you can write break statements inside it (only applies to range/ndrange for loops). This is equivalent to setting parallelize to 1.

@ti.kernel
def break_in_serial_for() -> ti.i32:
    a = 0
    ti.loop_config(serialize=True)
    for i in range(100):  # This loop runs serially
        a += i
        if i == 10:
            break
    return a


result = break_in_serial_for()  # returns 55
print(f"Result of break_in_serial_for: {result}")


# Example of using loop_config with ndarray
@ti.kernel
def fill_array(val: ti.types.ndarray(ndim=1)):
    n = val.shape[0]
    ti.loop_config(parallelize=8)
    # If the kernel is run on the CPU backend, 8 threads will be used to run it
    for i in range(n):
        val[i] = i


n = 128
val_array = np.zeros(n, dtype=np.int32)
fill_array(val_array)
print(f"First 10 values: {val_array[:10]}")
Result of break_in_serial_for: 55
First 10 values: [0 1 2 3 4 5 6 7 8 9]