BrainTaichi Introduction
This tutorial provides a comprehensive guide on how to develop custom operators using BrainTaichi.
Kernel Registration Interface
Brain dynamics computations are sparse and event-driven. However, dedicated operators for brain dynamics are not yet well abstracted or standardized, so we often need to write custom operators. In this tutorial, we will explore how to write custom brain dynamics operators with BrainTaichi.
Start by importing the relevant Python packages.
import jax
import jax.numpy as jnp
import taichi as ti
import braintaichi as bti
Basic Structure of Custom Operators
Taichi uses Python functions and decorators to define custom operators. Here is the basic structure of a custom operator:
@ti.kernel
def my_kernel(arg1: ti.types.ndarray(), arg2: ti.types.ndarray()):
    # Internal logic of the operator
    ...
The @ti.kernel decorator marks the function as a Taichi kernel. Instead of running it as ordinary Python, Taichi compiles it to native CPU or GPU code. Parameters annotated with ti.types.ndarray() accept multidimensional arrays.
Defining Helper Functions
When defining complex custom operators, you can use the @ti.func decorator to define helper functions. These functions can be called inside the kernel function:
@ti.func
def helper_func(x: ti.f32) -> ti.f32:
    # Auxiliary computation
    return x * 2

@ti.kernel
def my_kernel(arg: ti.types.ndarray()):
    for i in ti.ndrange(arg.shape[0]):
        arg[i] *= helper_func(arg[i])
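Note that my_kernel multiplies each element by helper_func(arg[i]), which is 2x. As a result, every element x becomes 2x² rather than 2x. The plain-Python mirror below makes this arithmetic easy to check. It is a sketch that does not use Taichi, and the name my_kernel_reference is hypothetical.

```python
def helper_func(x: float) -> float:
    # Same auxiliary computation as the Taichi helper: double the input
    return x * 2


def my_kernel_reference(arg: list[float]) -> None:
    # Plain-Python mirror of my_kernel: updates arg in place, one element at a time
    for i in range(len(arg)):
        arg[i] *= helper_func(arg[i])  # x * (2 * x) == 2 * x**2


data = [1.0, 2.0, 3.0]
my_kernel_reference(data)
print(data)  # [2.0, 8.0, 18.0]
```

If plain doubling is what you want, the loop body should be arg[i] = helper_func(arg[i]) instead.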
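Example: Custom Event Processing Operator
The following example demonstrates how to customize an event processing operator:
@ti.func
def get_weight(weight: ti.types.ndarray()) -> ti.f32:
    # Read the single (homogeneous) weight value
    return weight[None]

@ti.func
def update_output(out: ti.types.ndarray(), index: ti.i32, weight_val: ti.f32):
    # Accumulate the weight into the target output position
    out[index] += weight_val

@ti.kernel
def event_ell_cpu(indices: ti.types.ndarray(),
                  vector: ti.types.ndarray(),
                  weight: ti.types.ndarray(),
                  out: ti.types.ndarray()):
    weight_val = get_weight(weight)
    num_rows, num_cols = indices.shape
    ti.loop_config(serialize=True)  # run the outer loop serially on the CPU
    for i in range(num_rows):
        if vector[i]:  # only rows with an event contribute
            for j in range(num_cols):
                update_output(out, indices[i, j], weight_val)

# Sketch of a GPU variant with the same parameter order (an assumption, not a
# tuned implementation): the outer loop is left parallel. Taichi makes `+=`
# atomic inside parallel loops, so concurrent updates to out are safe.
@ti.kernel
def event_ell_gpu(indices: ti.types.ndarray(),
                  vector: ti.types.ndarray(),
                  weight: ti.types.ndarray(),
                  out: ti.types.ndarray()):
    weight_val = get_weight(weight)
    num_rows, num_cols = indices.shape
    for i in range(num_rows):
        if vector[i]:
            for j in range(num_cols):
                update_output(out, indices[i, j], weight_val)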
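In the parameter list, the output arrays must be declared last so that the operator compiles and is registered correctly. Here, event_ell_cpu takes three inputs and one output. The first input is indices, a 2-D array in ELL format with a fixed number of column indices per row. The second is vector, the event vector with one entry per row. The third is weight, a single weight value. The output is out. For every row i whose event vector[i] is set, the kernel adds the weight to each position out[indices[i, j]].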
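A plain-Python reference implementation is handy for checking a custom kernel's output on small inputs. The sketch below mirrors the logic of event_ell_cpu. The function name event_ell_reference and the num_post parameter are illustrative, and the sketch assumes the output starts at zero.

```python
def event_ell_reference(indices, vector, weight, num_post):
    """Plain-Python mirror of event_ell_cpu (illustrative, not part of BrainTaichi).

    indices  -- ELL rows: indices[i] lists the output positions updated by row i
    vector   -- event flags, one per row
    weight   -- the single weight value
    num_post -- length of the output; assumed to start at zero
    """
    out = [0.0] * num_post
    for i, row in enumerate(indices):
        if vector[i]:  # rows without an event are skipped entirely
            for target in row:
                out[target] += weight
    return out


indices = [[0, 2], [1, 2], [0, 3]]
vector = [True, False, True]
print(event_ell_reference(indices, vector, 0.5, num_post=4))  # [1.0, 0.0, 0.5, 0.5]
```

Only rows 0 and 2 carry events, so position 0 receives the weight twice and position 1 receives nothing. Comparing such a result against the kernel's output on the same small input is a quick sanity test.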
Registering and Using Custom Operators
After defining a custom operator, you can register it with BrainTaichi and use it wherever it is needed.
BrainTaichi provides XLACustomOp, a simple and flexible interface for registering custom operators. When registering, you can specify both cpu_kernel and gpu_kernel so that the operator can run on different devices. When calling the operator, pass the outs parameter, using jax.ShapeDtypeStruct to describe the shape and data type of each output.
Note: pass the inputs in the same order in which the kernel declares its parameters. The outputs described by outs correspond to the trailing (output) parameters.
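To make the ordering rule concrete, the sketch below binds call-site arrays to a kernel's parameter names, with inputs first and outputs after them. bind_call and event_ell_stub are hypothetical helpers written for illustration. They model the convention described above and are not BrainTaichi's internal implementation.

```python
import inspect


def event_ell_stub(indices, vector, weight, out):
    """Plain-Python stand-in with the same parameter order as event_ell_cpu."""


def bind_call(kernel, inputs, outs):
    """Map positional inputs, then outputs, onto the kernel's declared parameters."""
    names = list(inspect.signature(kernel).parameters)
    if len(names) != len(inputs) + len(outs):
        raise ValueError(
            f"kernel declares {len(names)} parameters, "
            f"got {len(inputs)} inputs + {len(outs)} outputs"
        )
    return dict(zip(names, list(inputs) + list(outs)))


print(bind_call(event_ell_stub, ["I", "V", "W"], ["OUT"]))
# {'indices': 'I', 'vector': 'V', 'weight': 'W', 'out': 'OUT'}
print(bind_call(event_ell_stub, ["I", "W", "V"], ["OUT"]))
# {'indices': 'I', 'vector': 'W', 'weight': 'V', 'out': 'OUT'}
```

The second call shows why order matters. Swapping two inputs at the call site raises no error, but the weight array silently lands in the vector parameter.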
# Taichi operator registration
prim = bti.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)

# Using the operator
def test_taichi_op():
    # Create input data
    # ...
    # Call the custom operator
    out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
    # Output the result
    print(out)
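The input-creation step is left out above. As one possible illustration, the sketch below builds ELL indices from a small dense 0/1 connectivity matrix. dense_to_ell is a hypothetical helper. It assumes every row has the same number of nonzero entries, because the ELL layout used by the kernel has a fixed width num_cols.

```python
def dense_to_ell(conn):
    """Convert a dense 0/1 matrix (rows x columns) into ELL column indices.

    Row i of the result lists the columns where conn[i] is nonzero. ELL storage
    uses a fixed width, so every row must have the same number of nonzeros.
    """
    indices = [[j for j, c in enumerate(row) if c] for row in conn]
    if len({len(r) for r in indices}) > 1:
        raise ValueError("ELL format needs the same number of nonzeros per row")
    return indices


conn = [
    [1, 0, 1, 0],
    [0, 1, 1, 0],
    [1, 0, 0, 1],
]
print(dense_to_ell(conn))  # [[0, 2], [1, 2], [0, 3]]
```

You could then convert the nested list with jnp.asarray(indices, dtype=jnp.int32) before passing it to prim. The output length s would presumably be the number of columns (4 here).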
Basic Taichi Concepts
Taichi is a domain-specific language (DSL) designed to simplify the development of high-performance visual computing and physics simulation algorithms, particularly for computer graphics researchers. Here are some of its basic concepts:
Embedded in Python
Taichi is embedded within Python, allowing developers to leverage the simplicity and flexibility of Python while benefiting from the performance of native GPU or CPU instructions. This means that if you are familiar with Python, you can quickly start using Taichi without learning a completely new language.
Just-in-Time (JIT) Compilation
Taichi uses just-in-time compilation through backends such as LLVM and SPIR-V to translate Python code into native GPU or CPU instructions. Kernels are compiled when they are first called, so you can develop interactively in Python while the compiled code runs at native speed.
Imperative Programming Paradigm
Unlike many other DSLs that focus on specific computing patterns, Taichi adopts an imperative programming paradigm. This provides greater flexibility and allows developers to write complex computations in a single kernel, which Taichi refers to as a “mega-kernel.”
Optimizations
Taichi employs various compiler optimizations such as common subexpression elimination, dead code elimination, and control flow graph analysis. These optimizations are backend-neutral, thanks to Taichi’s own intermediate representation (IR) layer.
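The first two optimizations named above can be pictured at the source level. The sketch below is a hand-written plain-Python illustration of what common subexpression elimination and dead code elimination do. Taichi applies these passes to its own IR, not to your Python source, and the function names are made up for this example.

```python
def before(x, y, z):
    unused = x ** 3   # result never read: dead code
    a = x * y + z
    b = x * y - z     # x * y is computed a second time
    return a, b


def after(x, y, z):
    t = x * y         # common subexpression computed once (CSE)
    a = t + z         # the dead `unused` assignment is gone (DCE)
    b = t - z
    return a, b


print(before(2, 3, 4), after(2, 3, 4))  # (10, 2) (10, 2)
```

Both versions return the same values. The optimized one simply does less work.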
Community and Backend Support
Taichi has a strong and dedicated community that has contributed to the development of various backends, including Vulkan, OpenGL, and DirectX. This wide range of backend support enhances Taichi’s portability and usability across different platforms.
Kernel Optimization in Taichi
Taichi kernels automatically parallelize for loops in the outermost scope. The compiler chooses loop settings automatically to make good use of the target architecture. For users who need the last few percent of performance, Taichi also provides APIs for fine-tuning. On GPUs, choosing a proper block_dim is key.
You can use ti.loop_config to set the loop directives for the next for loop. Available directives are:
parallelize: sets the number of threads to use on the CPU.
block_dim: sets the number of threads in a block on the GPU.
serialize: if set to True, the for loop runs serially, and you can write break statements inside it (only applies to range/ndrange for loops). This is equivalent to setting parallelize to 1.
@ti.kernel
def break_in_serial_for() -> ti.i32:
    a = 0
    ti.loop_config(serialize=True)
    for i in range(100):  # This loop runs serially
        a += i
        if i == 10:
            break
    return a

break_in_serial_for()  # returns 55
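The result 55 is 0 + 1 + ... + 10, because the loop adds i and then stops once i == 10. A break like this only has a well-defined meaning when iterations run in order, which is why it requires serialize=True. A plain-Python equivalent confirms the value. The name break_in_serial_for_reference is made up for this sketch.

```python
def break_in_serial_for_reference() -> int:
    # Same loop as the Taichi kernel, executed by the Python interpreter
    a = 0
    for i in range(100):
        a += i
        if i == 10:
            break
    return a


result = break_in_serial_for_reference()
assert result == 10 * 11 // 2  # closed form for 0 + 1 + ... + 10
print(result)  # 55
```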
n = 128
val = ti.field(ti.i32, shape=n)

@ti.kernel
def fill():
    ti.loop_config(parallelize=8, block_dim=16)
    # If the kernel is run on the CPU backend, 8 threads will be used to run it
    # If the kernel is run on the CUDA backend, each block will have 16 threads.
    for i in range(n):
        val[i] = i
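The sketch below gives a conceptual picture of block_dim=16 with n = 128. It groups the loop indices into consecutive blocks of 16 threads, with thread t of block b handling index b * block_dim + t. This is a simplified one-index-per-thread model built on a hypothetical helper. Taichi's compiler and runtime decide the actual launch configuration, such as how many blocks are launched.

```python
def block_layout(n: int, block_dim: int) -> list[list[int]]:
    """Assign indices 0..n-1 to blocks of block_dim threads (conceptual model)."""
    num_blocks = (n + block_dim - 1) // block_dim  # ceiling division
    return [
        [b * block_dim + t for t in range(block_dim) if b * block_dim + t < n]
        for b in range(num_blocks)
    ]


blocks = block_layout(128, 16)
print(len(blocks), blocks[1][:3], blocks[-1][-1])  # 8 [16, 17, 18] 127
```

When n is not a multiple of block_dim, the last block is only partly used. For example, block_layout(20, 16) leaves 12 of the 16 threads in the second block idle.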