XLACustomOp
- class braintaichi.XLACustomOp(cpu_kernel=None, gpu_kernel=None, batching_translation=None, jvp_translation=None, transpose_translation=None, name=None)
Creates an XLA custom call operator.
- Parameters:
cpu_kernel (Optional[Callable]) – Callable. The function that defines the computation on the CPU backend.
gpu_kernel (Union[Callable, str, None]) – Callable. The function that defines the computation on the GPU backend.
batching_translation (Optional[Callable]) – Callable. The JAX batching translation rule.
jvp_translation (Optional[Callable]) – Callable. The JAX JVP translation rule.
transpose_translation (Optional[Callable]) – Callable. The JAX transpose translation rule.
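The three translation-rule parameters correspond to JAX transformations: batching (`vmap`), forward-mode differentiation (JVP), and transposition of linear maps (used in reverse-mode differentiation). The reference above does not state the signature each rule must have, so the following is only a conceptual sketch in plain Python of what such rules compute for a toy operator. The names `scale`, `scale_jvp`, `scale_transpose_x`, and `scale_batched` are hypothetical and are not part of braintaichi.

```python
# Conceptual sketch only: plain-Python stand-ins, NOT the braintaichi calling convention.

def scale(w, x):
    """Toy operator: multiply every element of the list x by the scalar w."""
    return [w * xi for xi in x]


def scale_jvp(primals, tangents):
    """JVP rule: return the output and its tangent, d(w*x) = dw*x + w*dx."""
    w, x = primals
    dw, dx = tangents
    out = scale(w, x)
    tangent_out = [dw * xi + w * dxi for xi, dxi in zip(x, dx)]
    return out, tangent_out


def scale_transpose_x(w, cotangent):
    """Transpose rule for the linear map x -> w*x, which is again scaling by w."""
    return [w * ci for ci in cotangent]


def scale_batched(w, xs):
    """Batching rule: apply the operator along a leading batch axis of xs."""
    return [scale(w, x) for x in xs]


out, tangent = scale_jvp((2.0, [1.0, 3.0]), (0.5, [1.0, 0.0]))
batched = scale_batched(2.0, [[1.0, 3.0], [0.0, 4.0]])
```

A transpose rule can be sanity-checked with the inner-product identity `<scale(w, x), c> == <x, scale_transpose_x(w, c)>`, which holds for any linear map and its transpose.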
- def_abstract_eval(fun)
Define the abstract evaluation function.
- Parameters:
fun – The abstract evaluation function.
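In JAX, an abstract evaluation function works on shapes and dtypes only, never on array data: given descriptions of the inputs, it returns descriptions of the outputs. The sketch below illustrates that idea in plain Python for a matrix-vector product. `AbstractArray` and `matvec_abstract_eval` are hypothetical names introduced here, and the exact arguments that `def_abstract_eval` passes to `fun` are not specified in this reference, so treat the signature as an assumption.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class AbstractArray:
    """Hypothetical stand-in for an abstract value: shape and dtype, no data."""
    shape: Tuple[int, ...]
    dtype: str


def matvec_abstract_eval(matrix: AbstractArray, vector: AbstractArray) -> AbstractArray:
    """Infer the output shape and dtype of matrix @ vector without computing it."""
    if len(matrix.shape) != 2 or len(vector.shape) != 1:
        raise ValueError("expected a 2-D matrix and a 1-D vector")
    if matrix.shape[1] != vector.shape[0]:
        raise ValueError("inner dimensions do not match")
    if matrix.dtype != vector.dtype:
        raise TypeError("matrix and vector must share a dtype")
    # The result has one entry per matrix row and keeps the input dtype.
    return AbstractArray(shape=(matrix.shape[0],), dtype=matrix.dtype)


result = matvec_abstract_eval(
    AbstractArray((1000, 64), "float32"),
    AbstractArray((64,), "float32"),
)
```

Validating shapes in the abstract evaluation step means mismatches surface at trace time rather than inside the compiled kernel.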
- def_mlir_lowering(platform, fun)
Define the MLIR lowering rule.
- Parameters:
platform – str. The computing platform.
fun – The lowering rule.