Source code for braintaichi._primitive._batch_utils

# Copyright 2024- BrainPy Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from functools import partial

import jax.numpy as jnp
from jax import lax
from jax.interpreters import batching
from jax.tree_util import tree_flatten, tree_unflatten

__all__ = [
    'register_general_batching',
]


def _general_batching_rule(prim, args, axes, **kwargs):
    batch_axes, batch_args, non_batch_args = [], {}, {}
    for ax_i, ax in enumerate(axes):
        if ax is None:
            non_batch_args[f'ax{ax_i}'] = args[ax_i]
        else:
            batch_args[f'ax{ax_i}'] = args[ax_i] if ax == 0 else jnp.moveaxis(args[ax_i], ax, 0)
            batch_axes.append(ax_i)

    def f(_, x):
        pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}'])
                      for i in range(len(axes))])
        return 0, prim.bind(*pars, **kwargs)

    _, outs = lax.scan(f, 0, batch_args)
    out_vals, out_tree = tree_flatten(outs)
    out_dim = tree_unflatten(out_tree, (0,) * len(out_vals))
    return outs, out_dim


[docs] def register_general_batching(prim): batching.primitive_batchers[prim] = partial(_general_batching_rule, prim)
def _shape_to_layout(shape): return tuple(range(len(shape) - 1, -1, -1))