1
1
"""Test "unspecified" behavior which we cannot easily test in the Array API test suite.
2
2
"""
3
+ import itertools
4
+
3
5
import pytest
4
6
import torch
5
7
@@ -51,7 +53,10 @@ def test_two_args(self):
51
53
def test_multi_arg (self ):
52
54
torch .set_default_dtype (torch .float32 )
53
55
54
- args = [1 , 2 , 3j , xp .arange (3 ), 4 , 5 , 6 ]
56
+ args = [1. , 5 , 3 , torch .asarray ([3 ], dtype = torch .float16 ), 5 , 6 , 1. ]
57
+ assert xp .result_type (* args ) == torch .float16
58
+
59
+ args = [1 , 2 , 3j , xp .arange (3 , dtype = xp .float32 ), 4 , 5 , 6 ]
55
60
assert xp .result_type (* args ) == xp .complex64
56
61
57
62
args = [1 , 2 , 3j , xp .float64 , 4 , 5 , 6 ]
@@ -60,5 +65,10 @@ def test_multi_arg(self):
60
65
args = [1 , 2 , 3j , xp .float64 , 4 , xp .asarray (3 , dtype = xp .int16 ), 5 , 6 , False ]
61
66
assert xp .result_type (* args ) == xp .complex128
62
67
68
+ i64 = xp .ones (1 , dtype = xp .int64 )
69
+ f16 = xp .ones (1 , dtype = xp .float16 )
70
+ for i in itertools .permutations ([i64 , f16 , 1.0 , 1.0 ]):
71
+ assert xp .result_type (* i ) == xp .float16 , f"{ i } "
72
+
63
73
with pytest .raises (ValueError ):
64
74
xp .result_type (1 , 2 , 3 , 4 )
0 commit comments