```diff
     stack,
     switch,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise
 from pytensor.tensor.shape import shape, specify_broadcastable
 from pytensor.tensor.type import (
     DenseTensorType,
-    TensorType,
     complex_dtypes,
     continuous_dtypes,
     discrete_dtypes,
@@ -2868,93 +2868,7 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))
 
 
-class MatMul(Op):
-    __props__ = ("dtype",)
-
-    def __init__(self, dtype=None):
-        self.dtype = dtype
-
-    @classmethod
-    def _get_output_shape(cls, x1, x2, shapes, validate=False):
-        x1_shape, x2_shape = shapes
-
-        if x1.ndim == 1 and x2.ndim == 1:
-            if validate and x1_shape[0] != x2_shape[0]:
-                raise ValueError("1d inputs must have the same length.")
-            return ()
-        elif x1.ndim == 1 and x2.ndim > 1:
-            if validate and x1_shape[0] != x2_shape[-2]:
-                raise ValueError(
-                    "length of input 1 must be equal the length "
-                    "of the 2nd-last dimension of input 2"
-                )
-            return x2_shape[:-2] + x2_shape[-1:]
-        elif x1.ndim > 1 and x2.ndim == 1:
-            if validate and x1_shape[-1] != x2_shape[0]:
-                raise ValueError(
-                    "length of input 2 must be equal the length "
-                    "of the last dimension of input 1"
-                )
-            return x1_shape[:-1]
-        elif x1.ndim == 2 and x2.ndim == 2:
-            if validate and x1_shape[-1] != x2_shape[0]:
-                raise ValueError(
-                    "number of columns of input 1 must be equal to "
-                    "the number of rows of input 2"
-                )
-            return x1_shape[:-1] + x2_shape[-1:]
-        elif x1.ndim > 2 and x2.ndim == 2:
-            if validate and x1_shape[-1] != x2_shape[0]:
-                raise ValueError(
-                    "number of rows of input 2 must be equal to "
-                    "the length of the last dimension of input 1"
-                )
-            return x1_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
-        elif x1.ndim == 2 and x2.ndim > 2:
-            if validate and x1_shape[-1] != x2_shape[-2]:
-                raise ValueError(
-                    "number of columns of input 1 must be equal "
-                    "the length of the 2nd-last dimension of input 2"
-                )
-            return x2_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
-        else:
-            if validate:
-                from pytensor.tensor.random.basic import broadcast_shapes
-
-                bshape = broadcast_shapes(x1_shape[:-2], x2_shape[:-2])
-                if x1_shape[-1] != x2_shape[-2]:
-                    raise ValueError(
-                        "length of the last dimension of input 1 must be equal "
-                        "to the length of the 2nd-last dimension of input 2"
-                    )
-            else:
-                from pytensor.tensor.extra_ops import broadcast_shape
-
-                bshape = broadcast_shape(
-                    x1_shape[:-2], x2_shape[:-2], arrays_are_shapes=True
-                )
-            return bshape + x1_shape[-2:-1] + x2_shape[-1:]
-
-    def make_node(self, a, b):
-        a = as_tensor_variable(a)
-        b = as_tensor_variable(b)
-
-        if 0 in {a.ndim, b.ndim}:
-            raise ValueError("inputs to `matmul` cannot be scalar.")
-
-        out_shape = self._get_output_shape(
-            a, b, (a.type.shape, b.type.shape), validate=True
-        )
-        out = TensorType(dtype=self.dtype, shape=out_shape)()
-        return Apply(self, [a, b], [out])
-
-    def perform(self, node, inputs, outputs):
-        x1, x2 = inputs
-        outputs[0][0] = np.matmul(x1, x2, dtype=self.dtype)
-
-    def infer_shape(self, fgraph, node, shapes):
-        x1, x2 = node.inputs
-        return [self._get_output_shape(x1, x2, shapes)]
+_matrix_matrix_matmul = Blockwise(_dot, signature="(n,k),(k,m)->(n,m)")
 
 
 def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
@@ -2999,7 +2913,23 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
     - Stacks of matrices are broadcast together as if the matrices were elements,
       respecting the signature ``(n, k), (k, m) -> (n, m)``:
     """
-    return MatMul(dtype=dtype)(x1, x2)
+    x1 = as_tensor_variable(x1)
+    x2 = as_tensor_variable(x2)
+    if x1.type.ndim == 0 or x2.type.ndim == 0:
+        raise ValueError("matmul operand cannot be scalar")
+    if x1.type.ndim == 1 and x2.type.ndim == 1:
+        out = _dot(x1, x2)
+    elif x1.type.ndim == 1:
+        out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2)
+    elif x2.type.ndim == 1:
+        out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1)
+    else:
+        out = _matrix_matrix_matmul(x1, x2)
+
+    if dtype is not None:
+        out = out.astype(dtype)
+
+    return out
 
 
 __all__ = [
```
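The heart of this change is the one-liner `_matrix_matrix_matmul = Blockwise(_dot, signature="(n,k),(k,m)->(n,m)")`: `Blockwise` applies a core op to the trailing "core" dimensions of its inputs and broadcasts any leading dimensions against each other, in the spirit of a NumPy gufunc. For intuition, here is the same idea expressed in plain NumPy with `np.vectorize` and a gufunc signature (an illustrative analogue only, not the PyTensor implementation):

```python
import numpy as np

# NumPy analogue of Blockwise(_dot, signature="(n,k),(k,m)->(n,m)"):
# the core function sees only the trailing (n,k)/(k,m) dimensions,
# while all leading "batch" dimensions broadcast against each other.
batched_dot = np.vectorize(np.dot, signature="(n,k),(k,m)->(n,m)")

a = np.random.normal(size=(4, 2, 3))  # batch of 4 matrices, each (2, 3)
b = np.random.normal(size=(3, 5))     # a single (3, 5) matrix, no batch dims

assert batched_dot(a, b).shape == (4, 2, 5)
assert np.allclose(batched_dot(a, b), np.matmul(a, b))
```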
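On the PyTensor side, a minimal usage sketch of the rewritten `matmul` (not part of the diff, and it assumes a build that includes this change). Inputs with `ndim >= 2` go straight through `_matrix_matrix_matmul`, so stacks of matrices broadcast exactly as the docstring's `(n, k), (k, m) -> (n, m)` signature promises:

```python
import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x")  # a stack of matrices, shape (batch, n, k)
y = pt.matrix("y")   # a single matrix, shape (k, m)
z = pt.matmul(x, y)  # broadcasts to shape (batch, n, m)

f = pytensor.function([x, y], z)
a = np.random.normal(size=(4, 2, 3))
b = np.random.normal(size=(3, 5))
assert f(a, b).shape == (4, 2, 5)
assert np.allclose(f(a, b), np.matmul(a, b))
```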
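The two 1d branches reproduce NumPy's vector promotion rules: a 1d `x1` is promoted to a row vector with `x1[None]` and the inserted axis is squeezed back out of the result, while a 1d `x2` becomes a column vector via `x2[:, None]`, likewise squeezed. Spelled out under the same assumption:

```python
import numpy as np

import pytensor
import pytensor.tensor as pt

v = pt.vector("v")  # shape (k,)
M = pt.matrix("M")  # shape (k, m)

out = pt.matmul(v, M)                   # 1d case, result shape (m,)
explicit = pt.matmul(v[None, :], M)[0]  # the promotion written out by hand

f = pytensor.function([v, M], [out, explicit])
a = np.arange(3.0)
b = np.ones((3, 5))
r1, r2 = f(a, b)
assert r1.shape == (5,)
assert np.allclose(r1, r2)
assert np.allclose(r1, np.matmul(a, b))
```

When both inputs are 1d, the new code skips `Blockwise` entirely and calls `_dot` directly, which yields the 0d inner product.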