defjvp
- class braintaichi.defjvp(primitive, *jvp_rules)
Define JVP rules for any JAX primitive.
This function is similar to jax.interpreters.ad.defjvp. However, the JAX version only supports primitives with multiple_results=False. brainpy.math.defjvp makes it possible to define an independent JVP rule for each input parameter, regardless of whether multiple_results is False or True.
For examples, please see test_ad_support.py.
- Parameters:
primitive – The primitive to register the rules for, given as a Primitive or an XLACustomOp.
*jvp_rules – The JVP translation rule for each primal.
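
The per-input rule convention can be illustrated with JAX's own jax.interpreters.ad.defjvp, which the description above names as the closest analogue. This is only a sketch under stated assumptions: it does not call braintaichi.defjvp, the primitive sin_mul_p and the wrapper sin_mul are hypothetical names made up for illustration, and the primitive has a single result because that is all the JAX helper supports. Each rule receives the tangent of one input followed by all primals, and returns that input's contribution to the output tangent.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import ad

# Primitive moved to jax.extend.core in newer JAX releases; fall back for older ones.
try:
    from jax.extend.core import Primitive
except ImportError:
    from jax.core import Primitive

# Hypothetical single-result primitive computing sin(x) * y.
sin_mul_p = Primitive("sin_mul_example")
# Eager implementation only; using the primitive under jax.jit would also
# need an abstract evaluation rule and a lowering rule, omitted here.
sin_mul_p.def_impl(lambda x, y: jnp.sin(x) * y)


def sin_mul(x, y):
    """User-facing wrapper that binds the hypothetical primitive."""
    return sin_mul_p.bind(x, y)


# One JVP rule per input. Each rule is called as rule(tangent, *primals).
ad.defjvp(
    sin_mul_p,
    lambda t, x, y: t * jnp.cos(x) * y,  # contribution from the tangent of x
    lambda t, x, y: t * jnp.sin(x),      # contribution from the tangent of y
)

x, y = jnp.asarray(0.5), jnp.asarray(3.0)

# Directional derivative along x only, then along y only.
primal_out, tangent_x = jax.jvp(
    sin_mul, (x, y), (jnp.ones_like(x), jnp.zeros_like(y))
)
_, tangent_y = jax.jvp(
    sin_mul, (x, y), (jnp.zeros_like(x), jnp.ones_like(y))
)
```

JAX sums the contributions of all inputs whose tangents are non-zero, so each rule only needs to describe one partial derivative. According to the description above, the function documented here follows the same one-rule-per-input idea while also accepting primitives with multiple_results=True; the exact rule signature for that case is best checked against test_ad_support.py.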