Skip to content

Commit 32f798a

Browse files
amirafzali authored and facebook-github-bot committed
run unit tests with torchcomms RDMA
Differential Revision: D87329523
1 parent ae29d53 commit 32f798a

File tree

9 files changed

+83
-56
lines changed

9 files changed

+83
-56
lines changed

tests/test_models.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from torchstore.utils import spawn_actors
2121
from transformers import AutoModelForCausalLM
2222

23-
from .utils import main, transport_plus_strategy_params
23+
from .utils import main, set_transport_type, transport_plus_strategy_params
2424

2525
logger = getLogger(__name__)
2626

@@ -120,24 +120,24 @@ async def do_get(self):
120120

121121
@pytest.mark.parametrize(*transport_plus_strategy_params())
122122
@pytest.mark.asyncio
123-
async def test_basic(strategy_params, use_rdma):
123+
async def test_basic(strategy_params, transport_type):
124124
# FSDP
125125
put_mesh_shape = (1,)
126126
get_mesh_shape = (1,)
127-
await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], use_rdma)
127+
await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], transport_type)
128128

129129

130130
@pytest.mark.parametrize(*transport_plus_strategy_params())
131131
@pytest.mark.asyncio
132-
async def test_resharding(strategy_params, use_rdma):
132+
async def test_resharding(strategy_params, transport_type):
133133
# FSDP
134134
put_mesh_shape = (4,)
135135
get_mesh_shape = (8,)
136-
await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], use_rdma)
136+
await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], transport_type)
137137

138138

139-
async def _do_test(put_mesh_shape, get_mesh_shape, strategy, use_rdma):
140-
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
139+
async def _do_test(put_mesh_shape, get_mesh_shape, strategy, transport_type):
140+
set_transport_type(transport_type)
141141

142142
ts.init_logging()
143143
logger.info(f"Testing with strategy: {strategy}")

tests/test_resharding_basic.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset
1818
from torchstore.utils import get_local_tensor, spawn_actors
1919

20-
from .utils import DTensorActor, main, transport_plus_strategy_params
20+
from .utils import DTensorActor, main, set_transport_type, transport_plus_strategy_params
2121

2222
logger = getLogger(__name__)
2323

@@ -35,7 +35,7 @@
3535
@pytest.mark.asyncio
3636
async def test_1d_resharding(
3737
strategy_params,
38-
use_rdma,
38+
transport_type,
3939
put_mesh_shape,
4040
get_mesh_shape,
4141
put_sharding_dim,
@@ -50,13 +50,13 @@ async def test_1d_resharding(
5050
get_mesh_shape=get_mesh_shape,
5151
get_placements=[Shard(get_sharding_dim)],
5252
strategy=strategy,
53-
use_rdma=use_rdma,
53+
transport_type=transport_type,
5454
)
5555

5656

5757
@pytest.mark.parametrize(*transport_plus_strategy_params())
5858
@pytest.mark.asyncio
59-
async def test_2d_to_2d_resharding(strategy_params, use_rdma):
59+
async def test_2d_to_2d_resharding(strategy_params, transport_type):
6060
_, strategy = strategy_params
6161

6262
put_mesh_shape = get_mesh_shape = (2, 2)
@@ -69,13 +69,13 @@ async def test_2d_to_2d_resharding(strategy_params, use_rdma):
6969
get_mesh_shape=get_mesh_shape,
7070
get_placements=[Shard(dim) for dim in get_sharding_dims],
7171
strategy=strategy,
72-
use_rdma=use_rdma,
72+
transport_type=transport_type,
7373
)
7474

7575

7676
@pytest.mark.parametrize(*transport_plus_strategy_params())
7777
@pytest.mark.asyncio
78-
async def test_1d_to_2d_resharding(strategy_params, use_rdma):
78+
async def test_1d_to_2d_resharding(strategy_params, transport_type):
7979
_, strategy = strategy_params
8080

8181
put_mesh_shape = (4,)
@@ -89,13 +89,13 @@ async def test_1d_to_2d_resharding(strategy_params, use_rdma):
8989
get_mesh_shape=get_mesh_shape,
9090
get_placements=[Shard(dim) for dim in get_sharding_dims],
9191
strategy=strategy,
92-
use_rdma=use_rdma,
92+
transport_type=transport_type,
9393
)
9494

9595

9696
@pytest.mark.parametrize(*transport_plus_strategy_params())
9797
@pytest.mark.asyncio
98-
async def test_2d_to_1d_resharding(strategy_params, use_rdma):
98+
async def test_2d_to_1d_resharding(strategy_params, transport_type):
9999
_, strategy = strategy_params
100100

101101
put_mesh_shape = (2, 2)
@@ -109,13 +109,13 @@ async def test_2d_to_1d_resharding(strategy_params, use_rdma):
109109
get_mesh_shape=get_mesh_shape,
110110
get_placements=[Shard(dim) for dim in get_sharding_dims],
111111
strategy=strategy,
112-
use_rdma=use_rdma,
112+
transport_type=transport_type,
113113
)
114114

115115

116116
@pytest.mark.parametrize(*transport_plus_strategy_params())
117117
@pytest.mark.asyncio
118-
async def test_data_parallel(strategy_params, use_rdma):
118+
async def test_data_parallel(strategy_params, transport_type):
119119
_, strategy = strategy_params
120120

121121
# # 1d
@@ -128,7 +128,7 @@ async def test_data_parallel(strategy_params, use_rdma):
128128
get_mesh_shape=get_mesh_shape,
129129
get_placements=placements,
130130
strategy=strategy,
131-
use_rdma=use_rdma,
131+
transport_type=transport_type,
132132
)
133133

134134
# 2d -> 1d
@@ -143,7 +143,7 @@ async def test_data_parallel(strategy_params, use_rdma):
143143
get_mesh_shape=get_mesh_shape,
144144
get_placements=[Shard(1)],
145145
strategy=strategy,
146-
use_rdma=use_rdma,
146+
transport_type=transport_type,
147147
)
148148

149149

@@ -153,7 +153,7 @@ async def _test_resharding(
153153
get_mesh_shape: Tuple[int],
154154
get_placements: List[Union[Replicate, Shard]],
155155
strategy: ts.TorchStoreStrategy,
156-
use_rdma: bool,
156+
transport_type: str,
157157
):
158158
"""Given a "put" mesh shape and a "get" mesh shape.
159159
1. Create separate worlds for each mesh shape, running on different devices /PGs.
@@ -177,7 +177,7 @@ async def _test_resharding(
177177
178178
# Rank0: dtensor._local_tensor == [0,1], Rank1: dtensor._local_tensor == [2,3]
179179
"""
180-
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
180+
set_transport_type(transport_type)
181181

182182
put_world_size = math.prod(put_mesh_shape)
183183
get_world_size = math.prod(get_mesh_shape)

tests/test_resharding_ext.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torch.distributed._tensor import Shard
1212

1313
from .test_resharding_basic import _test_resharding
14-
from .utils import main, transport_plus_strategy_params
14+
from .utils import main, set_transport_type, transport_plus_strategy_params
1515

1616
logger = getLogger(__name__)
1717

@@ -45,7 +45,7 @@ def slow_tests_enabled():
4545
@pytest.mark.asyncio
4646
async def test_1d_resharding(
4747
strategy_params,
48-
use_rdma,
48+
transport_type,
4949
put_mesh_shape,
5050
get_mesh_shape,
5151
put_sharding_dim,
@@ -60,14 +60,14 @@ async def test_1d_resharding(
6060
get_mesh_shape=get_mesh_shape,
6161
get_placements=[Shard(get_sharding_dim)],
6262
strategy=strategy,
63-
use_rdma=use_rdma,
63+
transport_type=transport_type,
6464
)
6565

6666

6767
@requires_slow_tests_enabled
6868
@pytest.mark.parametrize(*transport_plus_strategy_params())
6969
@pytest.mark.asyncio
70-
async def test_2d_to_2d_resharding(strategy_params, use_rdma):
70+
async def test_2d_to_2d_resharding(strategy_params, transport_type):
7171
_, strategy = strategy_params
7272

7373
put_mesh_shape = get_mesh_shape = (2, 2)
@@ -83,13 +83,13 @@ async def test_2d_to_2d_resharding(strategy_params, use_rdma):
8383
get_mesh_shape=get_mesh_shape,
8484
get_placements=[Shard(dim) for dim in get_sharding_dims],
8585
strategy=strategy,
86-
use_rdma=use_rdma,
86+
transport_type=transport_type,
8787
)
8888

8989

9090
@pytest.mark.parametrize(*transport_plus_strategy_params())
9191
@pytest.mark.asyncio
92-
async def test_1d_to_2d_resharding(strategy_params, use_rdma):
92+
async def test_1d_to_2d_resharding(strategy_params, transport_type):
9393
_, strategy = strategy_params
9494

9595
put_mesh_shape = (4,)
@@ -106,13 +106,13 @@ async def test_1d_to_2d_resharding(strategy_params, use_rdma):
106106
get_mesh_shape=get_mesh_shape,
107107
get_placements=[Shard(dim) for dim in get_sharding_dims],
108108
strategy=strategy,
109-
use_rdma=use_rdma,
109+
transport_type=transport_type,
110110
)
111111

112112

113113
@pytest.mark.parametrize(*transport_plus_strategy_params())
114114
@pytest.mark.asyncio
115-
async def test_2d_to_1d_resharding(strategy_params, use_rdma):
115+
async def test_2d_to_1d_resharding(strategy_params, transport_type):
116116
_, strategy = strategy_params
117117

118118
put_mesh_shape = (2, 2)
@@ -129,7 +129,7 @@ async def test_2d_to_1d_resharding(strategy_params, use_rdma):
129129
get_mesh_shape=get_mesh_shape,
130130
get_placements=[Shard(dim) for dim in get_sharding_dims],
131131
strategy=strategy,
132-
use_rdma=use_rdma,
132+
transport_type=transport_type,
133133
)
134134

135135

tests/test_state_dict.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from torch.distributed.tensor import DTensor
2828
from torchstore.utils import spawn_actors
2929

30-
from .utils import main, transport_plus_strategy_params
30+
from .utils import main, set_transport_type, transport_plus_strategy_params
3131

3232
logger = getLogger(__name__)
3333

@@ -164,8 +164,8 @@ async def do_get(self):
164164

165165
@pytest.mark.parametrize(*transport_plus_strategy_params())
166166
@pytest.mark.asyncio
167-
async def test_state_dict(strategy_params, use_rdma):
168-
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
167+
async def test_state_dict(strategy_params, transport_type):
168+
set_transport_type(transport_type)
169169

170170
class Trainer(Actor):
171171
# Monarch RDMA does not work outside of an actor, so we need
@@ -209,8 +209,8 @@ async def do_test(self):
209209
@pytest.mark.skip("TODO(kaiyuan-li@): fix this test")
210210
@pytest.mark.parametrize(*transport_plus_strategy_params())
211211
@pytest.mark.asyncio
212-
async def test_dcp_sharding_parity(strategy_params, use_rdma):
213-
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
212+
async def test_dcp_sharding_parity(strategy_params, transport_type):
213+
set_transport_type(transport_type)
214214

215215
for save_mesh_shape, get_mesh_shape in [
216216
((2,), (4,)),

tests/test_store.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,17 @@
1414
from torchstore.logging import init_logging
1515
from torchstore.utils import spawn_actors
1616

17-
from .utils import main, transport_plus_strategy_params
17+
from .utils import main, set_transport_type, transport_plus_strategy_params
1818

1919
init_logging()
2020
logger = getLogger(__name__)
2121

2222

2323
@pytest.mark.parametrize(*transport_plus_strategy_params())
2424
@pytest.mark.asyncio
25-
async def test_basic(strategy_params, use_rdma):
25+
async def test_basic(strategy_params, transport_type):
2626
"""Test basic put/get functionality for multiple processes"""
27-
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
27+
set_transport_type(transport_type)
2828

2929
class PutGetActor(Actor):
3030
"""Each instance of this actor represents a single process."""
@@ -83,9 +83,9 @@ async def get(self, rank_offset=0):
8383

8484
@pytest.mark.parametrize(*transport_plus_strategy_params())
8585
@pytest.mark.asyncio
86-
async def test_objects(strategy_params, use_rdma):
86+
async def test_objects(strategy_params, transport_type):
8787
"""Test put/get on arbitrary object"""
88-
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
88+
set_transport_type(transport_type)
8989

9090
class ObjectActor(Actor):
9191
"""Each instance of this actor represents a single process."""
@@ -147,9 +147,9 @@ def __eq__(self, other: object) -> bool:
147147

148148
@pytest.mark.parametrize(*transport_plus_strategy_params())
149149
@pytest.mark.asyncio
150-
async def test_exists(strategy_params, use_rdma):
150+
async def test_exists(strategy_params, transport_type):
151151
"""Test the exists() API functionality"""
152-
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
152+
set_transport_type(transport_type)
153153

154154
class ExistsTestActor(Actor):
155155
"""Actor for testing exists functionality."""
@@ -216,9 +216,9 @@ async def exists(self, key):
216216

217217
@pytest.mark.parametrize(*transport_plus_strategy_params())
218218
@pytest.mark.asyncio
219-
async def test_delete(strategy_params, use_rdma):
219+
async def test_delete(strategy_params, transport_type):
220220
"""Test the delete() API functionality"""
221-
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
221+
set_transport_type(transport_type)
222222

223223
class DeleteTestActor(Actor):
224224
"""Actor for testing delete functionality."""

tests/test_tensor_slice.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818
from torchstore.transport.pipe import TensorSlice
1919
from torchstore.utils import spawn_actors
2020

21-
from .utils import DTensorActor, main, transport_plus_strategy_params
21+
from .utils import DTensorActor, main, set_transport_type, transport_plus_strategy_params
2222

2323

2424
@pytest.mark.parametrize(*transport_plus_strategy_params())
2525
@pytest.mark.asyncio
26-
async def test_get_tensor_slice(strategy_params, use_rdma):
26+
async def test_get_tensor_slice(strategy_params, transport_type):
2727
"""Test tensor slice API functionality"""
28-
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
28+
set_transport_type(transport_type)
2929

3030
class TensorSlicePutActor(Actor):
3131
"""Actor for putting tensors."""
@@ -204,7 +204,7 @@ async def get_tensor(self, key, tensor_slice_spec=None):
204204
return await ts.get(key, tensor_slice_spec=tensor_slice_spec)
205205

206206
# Use LocalRankStrategy with 2 storage volumes (no RDMA, no parametrization)
207-
os.environ["TORCHSTORE_RDMA_ENABLED"] = "0"
207+
set_transport_type("none")
208208
os.environ["LOCAL_RANK"] = "0" # Required by LocalRankStrategy
209209

210210
await ts.initialize(num_storage_volumes=2, strategy=ts.LocalRankStrategy())

0 commit comments

Comments (0)