Complete Example

Complete Example#

Here is a complete example showing how to implement a simple operator using the taichi custom operator:

import jax
import jax.numpy as jnp
import taichi as ti

import braintaichi as bti
import brainstate as bst


@ti.func
def get_weight(weight: ti.types.ndarray()) -> ti.f32:
  return weight[None]


@ti.func
def update_output(out: ti.types.ndarray(), index: ti.i32, weight_val: ti.f32):
  out[index] += weight_val


@ti.kernel
def event_ell_cpu(indices: ti.types.ndarray(),
                  vector: ti.types.ndarray(),
                  weight: ti.types.ndarray(),
                  out: ti.types.ndarray()):
  weight_val = get_weight(weight)
  num_rows, num_cols = indices.shape
  ti.loop_config(serialize=True)
  for i in range(num_rows):
    if vector[i]:
      for j in range(num_cols):
        update_output(out, indices[i, j], weight_val)

@ti.kernel
def event_ell_gpu(indices: ti.types.ndarray(),
                  vector: ti.types.ndarray(), 
                  weight: ti.types.ndarray(), 
                  out: ti.types.ndarray()):
  weight_val = get_weight(weight)
  num_rows, num_cols = indices.shape
  for i in range(num_rows):
    if vector[i]:
      for j in range(num_cols):
        update_output(out, indices[i, j], weight_val)

prim = bti.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)


def test_taichi_op_register():
  s = 1000
  indices = bst.random.randint(0, s, (s, 1000))
  vector = bst.random.rand(s) < 0.1

  out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

  out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

  print(out)

test_taichi_op_register()
[Taichi] version 1.7.2, llvm 15.0.1, commit 0131dce9, win, python 3.10.13
[Array([122., 106., 119., 123., 102., 125., 116., 113., 121., 106., 118.,
       111., 113., 115., 114., 118., 119., 109., 112., 105., 113., 121.,
       122., 113., 108., 116., 104.,  89., 108., 111., 110., 131., 119.,
       125., 117.,  98., 107., 118., 112., 121., 121., 119., 119., 122.,
       102., 126., 109., 116., 128., 113., 110., 119., 109., 110., 101.,
       128., 124., 112., 110., 132.,  98., 109., 111., 115., 134., 126.,
       125., 102., 133., 118., 112., 123., 123., 113., 119., 115., 123.,
       121., 128., 113., 123., 129., 108., 132., 109., 115.,  81., 125.,
       111., 132., 116., 127., 102., 126., 109., 123., 123., 106., 112.,
       131., 106., 117., 107., 112., 133., 130., 110., 109., 142., 101.,
       120., 111., 123., 126., 130., 139., 116., 128., 101., 118., 120.,
       123., 104., 119., 126., 113., 111., 113., 102., 110., 120., 133.,
       134., 111., 101., 120., 116., 106., 111., 105., 105., 112., 111.,
       103., 138., 115., 111.,  97., 109., 108., 111., 114., 103., 121.,
       111., 108., 119., 115., 102., 105., 125., 110., 118., 117., 120.,
       111., 108., 128., 105.,  96., 111., 113., 129., 114., 117., 117.,
       113., 104., 110., 112., 117., 108., 111., 117., 116.,  98., 137.,
       110., 100., 123., 124., 112., 127., 115.,  98., 120.,  99., 107.,
       111., 116., 122., 113., 132., 113., 118., 102., 124.,  97., 125.,
       115.,  99., 100., 147., 116., 120.,  92., 108., 122., 101., 125.,
       135., 124., 121., 112., 107., 125., 112., 108., 112., 121., 122.,
       120., 104., 123., 108.,  85., 111., 105., 120., 105., 101., 118.,
        90., 104., 116., 130.,  88.,  97., 124., 131., 127., 130., 101.,
       124., 104., 106., 114., 101., 121.,  95., 108., 141., 117., 108.,
       123., 128., 129., 118., 118., 122., 102., 129., 123., 114., 116.,
       133., 124., 113., 124., 102., 119., 122., 110.,  86., 113., 138.,
       123., 125., 103., 124., 133., 140., 109., 107., 128., 110., 103.,
       102., 122., 115.,  97., 107., 115., 128., 120., 114., 108., 123.,
       120., 125., 132., 105., 117., 106., 109., 122., 124., 128., 118.,
       113., 126., 128., 118., 104., 132., 114., 118., 114., 108., 128.,
        99.,  94., 117., 120., 127., 105., 119., 104., 122., 121., 109.,
       117., 121., 123., 112., 145., 137., 107., 115., 129., 103., 132.,
       111., 131., 129., 124.,  94., 107., 128., 122., 109., 121., 116.,
       116., 111., 127., 122., 138., 111., 126., 112., 116., 116., 102.,
       138., 120., 123., 120., 106., 115., 117., 129., 127.,  97., 112.,
       106., 110., 114., 107., 117., 120.,  85., 115., 119., 117., 106.,
       123., 117., 104., 105., 123., 121., 130., 112., 125., 118.,  96.,
       127., 127., 127., 112., 103., 118., 114., 125., 126., 119., 110.,
       122., 111., 109., 105., 115., 123., 108.,  89., 115., 111., 125.,
       116., 113., 128., 112., 103., 127., 102., 122., 102., 134., 125.,
       105., 100., 121., 104., 127., 104., 126.,  96., 109., 111., 124.,
       104., 115., 110., 103., 122., 119., 123., 116., 131., 113., 108.,
       123., 134., 118., 113., 101., 113., 140., 128., 123., 116., 125.,
       133., 106., 124., 114., 121., 118., 106., 113., 121., 100., 108.,
       105., 107., 141., 115., 120., 124., 103.,  96., 125., 113., 113.,
       112., 100., 108., 107., 109., 119., 110., 101., 132., 106., 136.,
        97., 123., 114.,  95., 109., 119., 114., 129., 103., 122., 106.,
       117., 114., 122., 117.,  90., 111., 117., 118., 123., 114., 101.,
       111., 134., 122., 110., 116., 119., 108., 106., 116., 113., 116.,
        88., 101., 114., 114., 112., 100., 110., 102., 120., 109., 118.,
       104., 102., 106., 115., 104., 119., 108.,  95., 122., 120., 111.,
       118., 122., 110., 132., 125., 125., 109., 112., 129., 132., 113.,
       118., 110., 112., 121., 111.,  98., 108., 108., 125., 124., 105.,
       110., 112., 121., 109., 109., 121., 115., 107., 112., 122., 126.,
       112., 123.,  97., 104., 108.,  98., 108., 124., 125., 119., 133.,
       120.,  95., 113., 133., 103., 123., 122., 115., 135., 115., 119.,
       126., 119., 122., 105., 104., 127., 108., 114., 110., 124., 136.,
       121., 105., 114., 114., 114., 120., 115., 107., 103., 114., 142.,
       130., 114., 117., 126., 134., 118., 114., 108., 117., 125., 116.,
       115.,  97., 121., 127., 129., 128., 139., 125., 106., 119., 119.,
       114., 127., 119.,  99., 151., 114., 105., 129., 113., 112., 107.,
        97., 104., 141., 112., 111., 106., 130., 114., 112., 125.,  99.,
       103.,  88., 125., 103., 125., 103.,  96., 106., 122., 104.,  90.,
       100., 112., 110.,  94., 100., 130., 111., 135., 119., 123., 117.,
       101.,  97., 140., 102., 110., 109., 116., 118., 116., 121., 117.,
       113.,  93., 113., 100., 109., 121., 107., 116., 131., 106., 120.,
       119., 117., 116., 114., 107., 113., 121., 115., 120., 114., 110.,
       102., 124., 115., 123., 118., 124.,  97.,  93., 114., 119.,  99.,
       119., 114., 112., 139., 106., 121.,  96., 125., 105., 108., 116.,
       127., 135., 138., 109., 113., 105., 120., 119., 114., 109.,  92.,
        98., 108., 109.,  96., 104., 124., 115., 150., 125., 125., 120.,
       127., 135., 124., 121., 131., 118., 131., 126., 103., 127., 116.,
        89., 124., 123., 109., 126., 104., 118., 101.,  97., 107., 133.,
       105., 108., 118., 115., 108., 107., 113., 110., 103., 121., 127.,
       102., 130., 118., 130., 111.,  97., 126., 118., 120., 138., 102.,
       115., 123., 121., 122., 124., 128., 124., 104., 111., 118., 145.,
       101., 121., 107., 114., 109., 106., 107., 105., 117.,  91., 113.,
       103., 120., 138., 111., 113., 128., 139., 101., 117., 107.,  83.,
       111., 107., 116., 103.,  88., 117., 127.,  99., 104., 123., 119.,
        87.,  89., 122., 109., 103., 110.,  97., 113., 119., 113., 111.,
       126., 100., 126., 112., 112., 107., 111., 123., 127., 129., 115.,
       115.,  98., 112., 108., 112., 110., 119., 102.,  95., 132., 112.,
       129., 104., 123., 111., 103., 124., 115., 101., 125., 108., 119.,
       143., 120., 113., 131., 116., 100., 102., 126., 111., 108., 103.,
       122., 129., 106., 117., 103., 110., 123., 122., 102., 124., 103.,
       109., 128., 119., 103., 124., 123., 121.,  93., 115., 110., 122.,
        91., 102., 116., 116., 116., 118., 107., 123., 134., 110., 107.,
       111., 106., 105., 137., 126., 122., 104., 109., 103., 124.,  84.,
       119., 100., 116., 113., 119., 120., 122., 107., 143., 121., 128.,
       124., 142., 118., 129., 116., 126., 109., 106., 100., 123., 103.,
       108., 120., 127., 116., 131., 101., 122., 118., 109., 116.],      dtype=float32)]

More Examples#

For more examples, please refer to: