Complete Example#
Here is a complete example showing how to implement a simple operator using the taichi custom operator:
import jax
import jax.numpy as jnp
import taichi as ti
import braintaichi as bti
import brainstate as bst
@ti.func
def get_weight(weight: ti.types.ndarray()) -> ti.f32:
return weight[None]
@ti.func
def update_output(out: ti.types.ndarray(), index: ti.i32, weight_val: ti.f32):
out[index] += weight_val
@ti.kernel
def event_ell_cpu(indices: ti.types.ndarray(),
vector: ti.types.ndarray(),
weight: ti.types.ndarray(),
out: ti.types.ndarray()):
weight_val = get_weight(weight)
num_rows, num_cols = indices.shape
ti.loop_config(serialize=True)
for i in range(num_rows):
if vector[i]:
for j in range(num_cols):
update_output(out, indices[i, j], weight_val)
@ti.kernel
def event_ell_gpu(indices: ti.types.ndarray(),
vector: ti.types.ndarray(),
weight: ti.types.ndarray(),
out: ti.types.ndarray()):
weight_val = get_weight(weight)
num_rows, num_cols = indices.shape
for i in range(num_rows):
if vector[i]:
for j in range(num_cols):
update_output(out, indices[i, j], weight_val)
prim = bti.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)
def test_taichi_op_register():
s = 1000
indices = bst.random.randint(0, s, (s, 1000))
vector = bst.random.rand(s) < 0.1
out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
print(out)
test_taichi_op_register()
[Taichi] version 1.7.2, llvm 15.0.1, commit 0131dce9, win, python 3.10.13
[Array([122., 106., 119., 123., 102., 125., 116., 113., 121., 106., 118.,
111., 113., 115., 114., 118., 119., 109., 112., 105., 113., 121.,
122., 113., 108., 116., 104., 89., 108., 111., 110., 131., 119.,
125., 117., 98., 107., 118., 112., 121., 121., 119., 119., 122.,
102., 126., 109., 116., 128., 113., 110., 119., 109., 110., 101.,
128., 124., 112., 110., 132., 98., 109., 111., 115., 134., 126.,
125., 102., 133., 118., 112., 123., 123., 113., 119., 115., 123.,
121., 128., 113., 123., 129., 108., 132., 109., 115., 81., 125.,
111., 132., 116., 127., 102., 126., 109., 123., 123., 106., 112.,
131., 106., 117., 107., 112., 133., 130., 110., 109., 142., 101.,
120., 111., 123., 126., 130., 139., 116., 128., 101., 118., 120.,
123., 104., 119., 126., 113., 111., 113., 102., 110., 120., 133.,
134., 111., 101., 120., 116., 106., 111., 105., 105., 112., 111.,
103., 138., 115., 111., 97., 109., 108., 111., 114., 103., 121.,
111., 108., 119., 115., 102., 105., 125., 110., 118., 117., 120.,
111., 108., 128., 105., 96., 111., 113., 129., 114., 117., 117.,
113., 104., 110., 112., 117., 108., 111., 117., 116., 98., 137.,
110., 100., 123., 124., 112., 127., 115., 98., 120., 99., 107.,
111., 116., 122., 113., 132., 113., 118., 102., 124., 97., 125.,
115., 99., 100., 147., 116., 120., 92., 108., 122., 101., 125.,
135., 124., 121., 112., 107., 125., 112., 108., 112., 121., 122.,
120., 104., 123., 108., 85., 111., 105., 120., 105., 101., 118.,
90., 104., 116., 130., 88., 97., 124., 131., 127., 130., 101.,
124., 104., 106., 114., 101., 121., 95., 108., 141., 117., 108.,
123., 128., 129., 118., 118., 122., 102., 129., 123., 114., 116.,
133., 124., 113., 124., 102., 119., 122., 110., 86., 113., 138.,
123., 125., 103., 124., 133., 140., 109., 107., 128., 110., 103.,
102., 122., 115., 97., 107., 115., 128., 120., 114., 108., 123.,
120., 125., 132., 105., 117., 106., 109., 122., 124., 128., 118.,
113., 126., 128., 118., 104., 132., 114., 118., 114., 108., 128.,
99., 94., 117., 120., 127., 105., 119., 104., 122., 121., 109.,
117., 121., 123., 112., 145., 137., 107., 115., 129., 103., 132.,
111., 131., 129., 124., 94., 107., 128., 122., 109., 121., 116.,
116., 111., 127., 122., 138., 111., 126., 112., 116., 116., 102.,
138., 120., 123., 120., 106., 115., 117., 129., 127., 97., 112.,
106., 110., 114., 107., 117., 120., 85., 115., 119., 117., 106.,
123., 117., 104., 105., 123., 121., 130., 112., 125., 118., 96.,
127., 127., 127., 112., 103., 118., 114., 125., 126., 119., 110.,
122., 111., 109., 105., 115., 123., 108., 89., 115., 111., 125.,
116., 113., 128., 112., 103., 127., 102., 122., 102., 134., 125.,
105., 100., 121., 104., 127., 104., 126., 96., 109., 111., 124.,
104., 115., 110., 103., 122., 119., 123., 116., 131., 113., 108.,
123., 134., 118., 113., 101., 113., 140., 128., 123., 116., 125.,
133., 106., 124., 114., 121., 118., 106., 113., 121., 100., 108.,
105., 107., 141., 115., 120., 124., 103., 96., 125., 113., 113.,
112., 100., 108., 107., 109., 119., 110., 101., 132., 106., 136.,
97., 123., 114., 95., 109., 119., 114., 129., 103., 122., 106.,
117., 114., 122., 117., 90., 111., 117., 118., 123., 114., 101.,
111., 134., 122., 110., 116., 119., 108., 106., 116., 113., 116.,
88., 101., 114., 114., 112., 100., 110., 102., 120., 109., 118.,
104., 102., 106., 115., 104., 119., 108., 95., 122., 120., 111.,
118., 122., 110., 132., 125., 125., 109., 112., 129., 132., 113.,
118., 110., 112., 121., 111., 98., 108., 108., 125., 124., 105.,
110., 112., 121., 109., 109., 121., 115., 107., 112., 122., 126.,
112., 123., 97., 104., 108., 98., 108., 124., 125., 119., 133.,
120., 95., 113., 133., 103., 123., 122., 115., 135., 115., 119.,
126., 119., 122., 105., 104., 127., 108., 114., 110., 124., 136.,
121., 105., 114., 114., 114., 120., 115., 107., 103., 114., 142.,
130., 114., 117., 126., 134., 118., 114., 108., 117., 125., 116.,
115., 97., 121., 127., 129., 128., 139., 125., 106., 119., 119.,
114., 127., 119., 99., 151., 114., 105., 129., 113., 112., 107.,
97., 104., 141., 112., 111., 106., 130., 114., 112., 125., 99.,
103., 88., 125., 103., 125., 103., 96., 106., 122., 104., 90.,
100., 112., 110., 94., 100., 130., 111., 135., 119., 123., 117.,
101., 97., 140., 102., 110., 109., 116., 118., 116., 121., 117.,
113., 93., 113., 100., 109., 121., 107., 116., 131., 106., 120.,
119., 117., 116., 114., 107., 113., 121., 115., 120., 114., 110.,
102., 124., 115., 123., 118., 124., 97., 93., 114., 119., 99.,
119., 114., 112., 139., 106., 121., 96., 125., 105., 108., 116.,
127., 135., 138., 109., 113., 105., 120., 119., 114., 109., 92.,
98., 108., 109., 96., 104., 124., 115., 150., 125., 125., 120.,
127., 135., 124., 121., 131., 118., 131., 126., 103., 127., 116.,
89., 124., 123., 109., 126., 104., 118., 101., 97., 107., 133.,
105., 108., 118., 115., 108., 107., 113., 110., 103., 121., 127.,
102., 130., 118., 130., 111., 97., 126., 118., 120., 138., 102.,
115., 123., 121., 122., 124., 128., 124., 104., 111., 118., 145.,
101., 121., 107., 114., 109., 106., 107., 105., 117., 91., 113.,
103., 120., 138., 111., 113., 128., 139., 101., 117., 107., 83.,
111., 107., 116., 103., 88., 117., 127., 99., 104., 123., 119.,
87., 89., 122., 109., 103., 110., 97., 113., 119., 113., 111.,
126., 100., 126., 112., 112., 107., 111., 123., 127., 129., 115.,
115., 98., 112., 108., 112., 110., 119., 102., 95., 132., 112.,
129., 104., 123., 111., 103., 124., 115., 101., 125., 108., 119.,
143., 120., 113., 131., 116., 100., 102., 126., 111., 108., 103.,
122., 129., 106., 117., 103., 110., 123., 122., 102., 124., 103.,
109., 128., 119., 103., 124., 123., 121., 93., 115., 110., 122.,
91., 102., 116., 116., 116., 118., 107., 123., 134., 110., 107.,
111., 106., 105., 137., 126., 122., 104., 109., 103., 124., 84.,
119., 100., 116., 113., 119., 120., 122., 107., 143., 121., 128.,
124., 142., 118., 129., 116., 126., 109., 106., 100., 123., 103.,
108., 120., 127., 116., 131., 101., 122., 118., 109., 116.], dtype=float32)]
More Examples#
For more examples, please refer to: