XLACustomOp

class braintaichi.XLACustomOp(cpu_kernel=None, gpu_kernel=None, batching_translation=None, jvp_translation=None, transpose_translation=None, name=None)

Create an XLA custom call operator.

Parameters:
  • cpu_kernel (Optional[Callable]) – The function that defines the computation on the CPU backend.

  • gpu_kernel (Union[Callable, str, None]) – The function that defines the computation on the GPU backend.

  • batching_translation (Optional[Callable]) – The JAX batching rule, used when the operator is vectorized with jax.vmap.

  • jvp_translation (Optional[Callable]) – The JAX JVP (forward-mode differentiation) rule.

  • transpose_translation (Optional[Callable]) – The JAX transpose rule, which reverse-mode differentiation uses for linear operations.

  • name (Optional[str]) – The name of the primitive.

def_abstract_eval(fun)

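Define the abstract evaluation rule, which infers the shapes and dtypes of the outputs from those of the inputs without running the computation.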

Parameters:

fun – The abstract evaluation function.
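To illustrate what an abstract evaluation rule does, here is a minimal pure-Python sketch for a hypothetical elementwise add. It assumes inputs are described only by shape and dtype; `ShapeDtype` and `add_abstract_eval` are hypothetical names. In real JAX code these would be abstract values such as jax.core.ShapedArray, and exactly what braintaichi expects `fun` to return is not stated on this page.

```python
from typing import NamedTuple, Tuple


class ShapeDtype(NamedTuple):
    """Hypothetical stand-in for an abstract array: shape and dtype only, no data."""
    shape: Tuple[int, ...]
    dtype: str


def add_abstract_eval(a: ShapeDtype, b: ShapeDtype) -> ShapeDtype:
    """Abstract evaluation for an elementwise add: infer the output spec from the input specs."""
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    if a.dtype != b.dtype:
        raise ValueError(f"dtype mismatch: {a.dtype} vs {b.dtype}")
    # The output has the same shape and dtype as the inputs; no values are computed.
    return ShapeDtype(a.shape, a.dtype)


out = add_abstract_eval(ShapeDtype((3, 4), "float32"), ShapeDtype((3, 4), "float32"))
print(out)  # ShapeDtype(shape=(3, 4), dtype='float32')
```

An operator with several results would presumably return one such spec per result.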

def_batching_rule(fun)

Define the batching rule, which JAX uses when the operator is vectorized with jax.vmap.

Parameters:

fun – The batching rule.
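A batching rule in JAX receives the batched arguments together with their batch dimensions (an int, or None for an unbatched argument) and returns the output together with its batch dimension. The sketch below follows that convention in pure Python, assuming that braintaichi forwards `fun` to JAX unchanged. `scale_kernel` and `scale_batching_rule` are hypothetical, arrays are plain lists, and only batch dimension 0 or None is handled.

```python
from typing import List, Optional, Sequence, Tuple


def scale_kernel(x: List[float], factor: float) -> List[float]:
    """Hypothetical unbatched kernel: multiply a 1-D array (a list) by a scalar."""
    return [factor * v for v in x]


def scale_batching_rule(
    batched_args: Sequence, batch_dims: Sequence[Optional[int]]
) -> Tuple[List[List[float]], int]:
    """Batching rule in JAX's (batched_args, batch_dims) -> (out, out_dim) form.

    This sketch only handles batch dimension 0 or None and loops over the
    batch, applying the unbatched kernel to each slice.
    """
    x, factor = batched_args
    x_bd, f_bd = batch_dims
    if x_bd not in (0, None) or f_bd not in (0, None):
        raise NotImplementedError("sketch only supports batch dim 0 or None")
    if x_bd is None and f_bd is None:
        raise ValueError("at least one argument must be batched")
    # The batch size comes from whichever argument is batched.
    size = len(x) if x_bd == 0 else len(factor)
    outs = []
    for i in range(size):
        xi = x[i] if x_bd == 0 else x
        fi = factor[i] if f_bd == 0 else factor
        outs.append(scale_kernel(xi, fi))
    # Results are stacked along a new leading axis, so the output batch dim is 0.
    return outs, 0


out, out_dim = scale_batching_rule(([[1.0, 2.0], [3.0, 4.0]], 10.0), (0, None))
print(out, out_dim)  # [[10.0, 20.0], [30.0, 40.0]] 0
```

Looping over the batch is the simplest valid rule. A production rule would typically launch the kernel once on the whole batch instead.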

def_jvp_rule(fun)

Define the JVP (forward-mode differentiation) rule.

Parameters:

fun – The JVP rule.
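A JAX JVP rule maps (primals, tangents) to (primals_out, tangents_out). The pure-Python sketch below shows that shape for a hypothetical primitive with two results, (sin x, cos x). It assumes tangents are plain floats. Real JAX rules may also receive symbolic zero tangents, which this sketch does not model. `sincos_kernel` and `sincos_jvp_rule` are hypothetical names.

```python
import math
from typing import Sequence, Tuple


def sincos_kernel(x: float) -> Tuple[float, float]:
    """Hypothetical primitive with two results: (sin x, cos x)."""
    return math.sin(x), math.cos(x)


def sincos_jvp_rule(primals: Sequence[float], tangents: Sequence[float]):
    """JVP rule in JAX's (primals, tangents) -> (primals_out, tangents_out) form."""
    (x,), (t,) = primals, tangents
    s, c = sincos_kernel(x)
    # d(sin x) = cos x * dx,  d(cos x) = -sin x * dx
    return (s, c), (c * t, -s * t)


primals_out, tangents_out = sincos_jvp_rule((0.7,), (1.0,))
```

Checking the tangents against a central finite difference is a quick way to validate a hand-written rule.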

def_mlir_lowering(platform, fun)

Define the MLIR lowering rule.

Parameters:
  • platform (str) – The computing platform.

  • fun – The lowering rule.

def_transpose_rule(fun)

Define the transpose rule, which JAX uses for reverse-mode differentiation of linear operations.

Parameters:

fun – The transpose rule.
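In JAX, a transpose rule receives the output cotangent followed by the primitive's arguments, where the linear inputs are marked as undefined, and returns one cotangent per argument (None for non-linear ones). The pure-Python sketch below follows that convention for a hypothetical op y = a * x that is linear in x. `UndefinedPrimal` here is a hypothetical stand-in for jax.interpreters.ad.UndefinedPrimal, and `scale` and `scale_transpose_rule` are illustrative names.

```python
from typing import List, Tuple


class UndefinedPrimal:
    """Hypothetical stand-in for JAX's marker of a linear (unknown) input."""

    def __init__(self, shape: Tuple[int, ...]):
        self.shape = shape


def scale(x: List[float], a: float) -> List[float]:
    """Hypothetical primitive, linear in x: y = a * x."""
    return [a * v for v in x]


def scale_transpose_rule(ct: List[float], x, a: float):
    """Transpose rule in JAX's (ct, *args) form: one cotangent per argument."""
    # Only x is linear here; a is an ordinary, known value.
    assert isinstance(x, UndefinedPrimal) and not isinstance(a, UndefinedPrimal)
    assert len(ct) == x.shape[0]
    # The transpose of "multiply by a" is again "multiply by a".
    x_ct = [a * c for c in ct]
    # No cotangent is produced for the non-linear argument a.
    return x_ct, None


x_ct, a_ct = scale_transpose_rule([0.5, -1.0, 4.0], UndefinedPrimal((3,)), 2.0)
print(x_ct, a_ct)  # [1.0, -2.0, 8.0] None
```

A correct transpose satisfies the inner-product identity <ct, A x> = <A^T ct, x> for any x, which makes a convenient unit test.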

def_xla_translation(platform, fun)

Define the XLA translation rule.

Parameters:
  • platform (str) – The computing platform.

  • fun – The XLA translation rule.

defjvp(*jvp_rules)

Define the JVP rule from per-input rules. Similar to jax.interpreters.ad.defjvp, but also supports primitives with multiple results.

Parameters:

jvp_rules – The JVP rules, one for each input of the primitive.
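The idea behind a multi-result defjvp can be sketched in pure Python. In jax.interpreters.ad.defjvp, rule i is called as rule(g, *primals) with the tangent g of input i, and the contributions of all inputs are summed. With multiple results, each rule presumably returns one tangent per result, and the sums are taken result by result. This is an illustrative re-implementation, not braintaichi's code. `make_multi_result_jvp` and `prod_and_sum` are hypothetical names, and None stands in for a zero tangent.

```python
from typing import Callable, Optional, Sequence, Tuple


def make_multi_result_jvp(fun: Callable, *jvp_rules: Optional[Callable]) -> Callable:
    """Combine per-input JVP rules into one (primals, tangents) JVP rule.

    jvp_rules[i](g, *primals) returns the tangent contribution of input i to
    every result, as a tuple with one entry per result. A tangent of None means
    "zero" and is skipped, mirroring how ad.defjvp skips symbolic zeros.
    """

    def jvp(primals: Sequence[float], tangents: Sequence[Optional[float]]):
        primals_out = fun(*primals)
        tangents_out = [0.0] * len(primals_out)
        for rule, g in zip(jvp_rules, tangents):
            if rule is None or g is None:
                continue
            contribution = rule(g, *primals)
            # Sum the contributions result by result.
            tangents_out = [acc + c for acc, c in zip(tangents_out, contribution)]
        return primals_out, tuple(tangents_out)

    return jvp


def prod_and_sum(x: float, y: float) -> Tuple[float, float]:
    """Hypothetical two-input, two-result primitive: f(x, y) = (x * y, x + y)."""
    return x * y, x + y


jvp = make_multi_result_jvp(
    prod_and_sum,
    lambda g, x, y: (g * y, g),  # contribution of dx to (d(xy), d(x+y))
    lambda g, x, y: (g * x, g),  # contribution of dy to (d(xy), d(x+y))
)
print(jvp((2.0, 3.0), (1.0, 0.5)))  # ((6.0, 5.0), (4.0, 1.5))
```

Writing one small rule per input keeps each rule simple, while the combining step handles summation and skips zero tangents.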