1717from torch .distributed .tensor ._utils import _compute_local_shape_and_global_offset
1818from torchstore .utils import get_local_tensor , spawn_actors
1919
20- from .utils import DTensorActor , main , transport_plus_strategy_params
20+ from .utils import DTensorActor , main , set_transport_type , transport_plus_strategy_params
2121
2222logger = getLogger (__name__ )
2323
3535@pytest .mark .asyncio
3636async def test_1d_resharding (
3737 strategy_params ,
38- use_rdma ,
38+ transport_type ,
3939 put_mesh_shape ,
4040 get_mesh_shape ,
4141 put_sharding_dim ,
@@ -50,13 +50,13 @@ async def test_1d_resharding(
5050 get_mesh_shape = get_mesh_shape ,
5151 get_placements = [Shard (get_sharding_dim )],
5252 strategy = strategy ,
53- use_rdma = use_rdma ,
53+ transport_type = transport_type ,
5454 )
5555
5656
5757@pytest .mark .parametrize (* transport_plus_strategy_params ())
5858@pytest .mark .asyncio
59- async def test_2d_to_2d_resharding (strategy_params , use_rdma ):
59+ async def test_2d_to_2d_resharding (strategy_params , transport_type ):
6060 _ , strategy = strategy_params
6161
6262 put_mesh_shape = get_mesh_shape = (2 , 2 )
@@ -69,13 +69,13 @@ async def test_2d_to_2d_resharding(strategy_params, use_rdma):
6969 get_mesh_shape = get_mesh_shape ,
7070 get_placements = [Shard (dim ) for dim in get_sharding_dims ],
7171 strategy = strategy ,
72- use_rdma = use_rdma ,
72+ transport_type = transport_type ,
7373 )
7474
7575
7676@pytest .mark .parametrize (* transport_plus_strategy_params ())
7777@pytest .mark .asyncio
78- async def test_1d_to_2d_resharding (strategy_params , use_rdma ):
78+ async def test_1d_to_2d_resharding (strategy_params , transport_type ):
7979 _ , strategy = strategy_params
8080
8181 put_mesh_shape = (4 ,)
@@ -89,13 +89,13 @@ async def test_1d_to_2d_resharding(strategy_params, use_rdma):
8989 get_mesh_shape = get_mesh_shape ,
9090 get_placements = [Shard (dim ) for dim in get_sharding_dims ],
9191 strategy = strategy ,
92- use_rdma = use_rdma ,
92+ transport_type = transport_type ,
9393 )
9494
9595
9696@pytest .mark .parametrize (* transport_plus_strategy_params ())
9797@pytest .mark .asyncio
98- async def test_2d_to_1d_resharding (strategy_params , use_rdma ):
98+ async def test_2d_to_1d_resharding (strategy_params , transport_type ):
9999 _ , strategy = strategy_params
100100
101101 put_mesh_shape = (2 , 2 )
@@ -109,13 +109,13 @@ async def test_2d_to_1d_resharding(strategy_params, use_rdma):
109109 get_mesh_shape = get_mesh_shape ,
110110 get_placements = [Shard (dim ) for dim in get_sharding_dims ],
111111 strategy = strategy ,
112- use_rdma = use_rdma ,
112+ transport_type = transport_type ,
113113 )
114114
115115
116116@pytest .mark .parametrize (* transport_plus_strategy_params ())
117117@pytest .mark .asyncio
118- async def test_data_parallel (strategy_params , use_rdma ):
118+ async def test_data_parallel (strategy_params , transport_type ):
119119 _ , strategy = strategy_params
120120
121121 # # 1d
@@ -128,7 +128,7 @@ async def test_data_parallel(strategy_params, use_rdma):
128128 get_mesh_shape = get_mesh_shape ,
129129 get_placements = placements ,
130130 strategy = strategy ,
131- use_rdma = use_rdma ,
131+ transport_type = transport_type ,
132132 )
133133
134134 # 2d -> 1d
@@ -143,7 +143,7 @@ async def test_data_parallel(strategy_params, use_rdma):
143143 get_mesh_shape = get_mesh_shape ,
144144 get_placements = [Shard (1 )],
145145 strategy = strategy ,
146- use_rdma = use_rdma ,
146+ transport_type = transport_type ,
147147 )
148148
149149
@@ -153,7 +153,7 @@ async def _test_resharding(
153153 get_mesh_shape : Tuple [int ],
154154 get_placements : List [Union [Replicate , Shard ]],
155155 strategy : ts .TorchStoreStrategy ,
156- use_rdma : bool ,
156+ transport_type : str ,
157157):
158158 """Given a "put" mesh shape and a "get" mesh shape.
159159 1. Create separate worlds for each mesh shape, running on different devices /PGs.
@@ -177,7 +177,7 @@ async def _test_resharding(
177177
178178 # Rank0: dtensor._local_tensor == [0,1], Rank1: dtensor._local_tensor == [2,3]
179179 """
180- os . environ [ "TORCHSTORE_RDMA_ENABLED" ] = "1" if use_rdma else "0"
180+ set_transport_type ( transport_type )
181181
182182 put_world_size = math .prod (put_mesh_shape )
183183 get_world_size = math .prod (get_mesh_shape )
0 commit comments