14 changes: 7 additions & 7 deletions tests/test_models.py
@@ -20,7 +20,7 @@
from torchstore.utils import spawn_actors
from transformers import AutoModelForCausalLM

from .utils import main, transport_plus_strategy_params
from .utils import main, set_transport_type, transport_plus_strategy_params

logger = getLogger(__name__)

@@ -120,24 +120,24 @@ async def do_get(self):

@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_basic(strategy_params, use_rdma):
async def test_basic(strategy_params, transport_type):
# FSDP
put_mesh_shape = (1,)
get_mesh_shape = (1,)
await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], use_rdma)
await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], transport_type)


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_resharding(strategy_params, use_rdma):
async def test_resharding(strategy_params, transport_type):
# FSDP
put_mesh_shape = (4,)
get_mesh_shape = (8,)
await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], use_rdma)
await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], transport_type)


async def _do_test(put_mesh_shape, get_mesh_shape, strategy, use_rdma):
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
async def _do_test(put_mesh_shape, get_mesh_shape, strategy, transport_type):
set_transport_type(transport_type)

ts.init_logging()
logger.info(f"Testing with strategy: {strategy}")
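The pattern in this file repeats across the whole PR: each test now takes a `transport_type` string instead of a `use_rdma` boolean, and the direct `os.environ["TORCHSTORE_RDMA_ENABLED"]` toggle is replaced by the shared `set_transport_type` helper imported from `tests/utils.py`. That helper's definition is not part of the hunks shown here; a minimal sketch, assuming it simply maps the string onto the old environment variable, might look like this:

```python
# Hypothetical sketch of tests/utils.py::set_transport_type. The real helper
# is not shown in this diff; the "rdma" value and the env-var mapping below
# are assumptions inferred from how the tests call it (e.g. set_transport_type("none")).
import os


def set_transport_type(transport_type: str) -> None:
    """Configure the torchstore transport for the current test process."""
    if transport_type == "rdma":
        os.environ["TORCHSTORE_RDMA_ENABLED"] = "1"
    else:
        # "none" (seen in test_tensor_slice.py) and any other value fall back
        # to the non-RDMA transport path.
        os.environ["TORCHSTORE_RDMA_ENABLED"] = "0"
```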
28 changes: 14 additions & 14 deletions tests/test_resharding_basic.py
@@ -17,7 +17,7 @@
from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset
from torchstore.utils import get_local_tensor, spawn_actors

from .utils import DTensorActor, main, transport_plus_strategy_params
from .utils import DTensorActor, main, set_transport_type, transport_plus_strategy_params

logger = getLogger(__name__)

@@ -35,7 +35,7 @@
@pytest.mark.asyncio
async def test_1d_resharding(
strategy_params,
use_rdma,
transport_type,
put_mesh_shape,
get_mesh_shape,
put_sharding_dim,
@@ -50,13 +50,13 @@ async def test_1d_resharding(
get_mesh_shape=get_mesh_shape,
get_placements=[Shard(get_sharding_dim)],
strategy=strategy,
use_rdma=use_rdma,
transport_type=transport_type,
)


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_2d_to_2d_resharding(strategy_params, use_rdma):
async def test_2d_to_2d_resharding(strategy_params, transport_type):
_, strategy = strategy_params

put_mesh_shape = get_mesh_shape = (2, 2)
@@ -69,13 +69,13 @@ async def test_2d_to_2d_resharding(strategy_params, use_rdma):
get_mesh_shape=get_mesh_shape,
get_placements=[Shard(dim) for dim in get_sharding_dims],
strategy=strategy,
use_rdma=use_rdma,
transport_type=transport_type,
)


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_1d_to_2d_resharding(strategy_params, use_rdma):
async def test_1d_to_2d_resharding(strategy_params, transport_type):
_, strategy = strategy_params

put_mesh_shape = (4,)
@@ -89,13 +89,13 @@ async def test_1d_to_2d_resharding(strategy_params, use_rdma):
get_mesh_shape=get_mesh_shape,
get_placements=[Shard(dim) for dim in get_sharding_dims],
strategy=strategy,
use_rdma=use_rdma,
transport_type=transport_type,
)


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_2d_to_1d_resharding(strategy_params, use_rdma):
async def test_2d_to_1d_resharding(strategy_params, transport_type):
_, strategy = strategy_params

put_mesh_shape = (2, 2)
@@ -109,13 +109,13 @@ async def test_2d_to_1d_resharding(strategy_params, use_rdma):
get_mesh_shape=get_mesh_shape,
get_placements=[Shard(dim) for dim in get_sharding_dims],
strategy=strategy,
use_rdma=use_rdma,
transport_type=transport_type,
)


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_data_parallel(strategy_params, use_rdma):
async def test_data_parallel(strategy_params, transport_type):
_, strategy = strategy_params

# # 1d
@@ -128,7 +128,7 @@ async def test_data_parallel(strategy_params, use_rdma):
get_mesh_shape=get_mesh_shape,
get_placements=placements,
strategy=strategy,
use_rdma=use_rdma,
transport_type=transport_type,
)

# 2d -> 1d
@@ -143,7 +143,7 @@ async def test_data_parallel(strategy_params, use_rdma):
get_mesh_shape=get_mesh_shape,
get_placements=[Shard(1)],
strategy=strategy,
use_rdma=use_rdma,
transport_type=transport_type,
)


@@ -153,7 +153,7 @@ async def _test_resharding(
get_mesh_shape: Tuple[int],
get_placements: List[Union[Replicate, Shard]],
strategy: ts.TorchStoreStrategy,
use_rdma: bool,
transport_type: str,
):
"""Given a "put" mesh shape and a "get" mesh shape.
1. Create separate worlds for each mesh shape, running on different devices /PGs.
@@ -177,7 +177,7 @@

# Rank0: dtensor._local_tensor == [0,1], Rank1: dtensor._local_tensor == [2,3]
"""
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
set_transport_type(transport_type)

put_world_size = math.prod(put_mesh_shape)
get_world_size = math.prod(get_mesh_shape)
18 changes: 9 additions & 9 deletions tests/test_resharding_ext.py
@@ -11,7 +11,7 @@
from torch.distributed._tensor import Shard

from .test_resharding_basic import _test_resharding
from .utils import main, transport_plus_strategy_params
from .utils import main, set_transport_type, transport_plus_strategy_params

logger = getLogger(__name__)

@@ -45,7 +45,7 @@ def slow_tests_enabled():
@pytest.mark.asyncio
async def test_1d_resharding(
strategy_params,
use_rdma,
transport_type,
put_mesh_shape,
get_mesh_shape,
put_sharding_dim,
@@ -60,14 +60,14 @@ async def test_1d_resharding(
get_mesh_shape=get_mesh_shape,
get_placements=[Shard(get_sharding_dim)],
strategy=strategy,
use_rdma=use_rdma,
transport_type=transport_type,
)


@requires_slow_tests_enabled
@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_2d_to_2d_resharding(strategy_params, use_rdma):
async def test_2d_to_2d_resharding(strategy_params, transport_type):
_, strategy = strategy_params

put_mesh_shape = get_mesh_shape = (2, 2)
@@ -83,13 +83,13 @@ async def test_2d_to_2d_resharding(strategy_params, use_rdma):
get_mesh_shape=get_mesh_shape,
get_placements=[Shard(dim) for dim in get_sharding_dims],
strategy=strategy,
use_rdma=use_rdma,
transport_type=transport_type,
)


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_1d_to_2d_resharding(strategy_params, use_rdma):
async def test_1d_to_2d_resharding(strategy_params, transport_type):
_, strategy = strategy_params

put_mesh_shape = (4,)
@@ -106,13 +106,13 @@ async def test_1d_to_2d_resharding(strategy_params, use_rdma):
get_mesh_shape=get_mesh_shape,
get_placements=[Shard(dim) for dim in get_sharding_dims],
strategy=strategy,
use_rdma=use_rdma,
transport_type=transport_type,
)


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_2d_to_1d_resharding(strategy_params, use_rdma):
async def test_2d_to_1d_resharding(strategy_params, transport_type):
_, strategy = strategy_params

put_mesh_shape = (2, 2)
@@ -129,7 +129,7 @@ async def test_2d_to_1d_resharding(strategy_params, use_rdma):
get_mesh_shape=get_mesh_shape,
get_placements=[Shard(dim) for dim in get_sharding_dims],
strategy=strategy,
use_rdma=use_rdma,
transport_type=transport_type,
)


10 changes: 5 additions & 5 deletions tests/test_state_dict.py
@@ -27,7 +27,7 @@
from torch.distributed.tensor import DTensor
from torchstore.utils import spawn_actors

from .utils import main, transport_plus_strategy_params
from .utils import main, set_transport_type, transport_plus_strategy_params

logger = getLogger(__name__)

@@ -164,8 +164,8 @@ async def do_get(self):

@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_state_dict(strategy_params, use_rdma):
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
async def test_state_dict(strategy_params, transport_type):
set_transport_type(transport_type)

class Trainer(Actor):
# Monarch RDMA does not work outside of an actor, so we need
@@ -209,8 +209,8 @@ async def do_test(self):
@pytest.mark.skip("TODO(kaiyuan-li@): fix this test")
@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_dcp_sharding_parity(strategy_params, use_rdma):
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
async def test_dcp_sharding_parity(strategy_params, transport_type):
set_transport_type(transport_type)

for save_mesh_shape, get_mesh_shape in [
((2,), (4,)),
18 changes: 9 additions & 9 deletions tests/test_store.py
@@ -14,17 +14,17 @@
from torchstore.logging import init_logging
from torchstore.utils import spawn_actors

from .utils import main, transport_plus_strategy_params
from .utils import main, set_transport_type, transport_plus_strategy_params

init_logging()
logger = getLogger(__name__)


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_basic(strategy_params, use_rdma):
async def test_basic(strategy_params, transport_type):
"""Test basic put/get functionality for multiple processes"""
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
set_transport_type(transport_type)

class PutGetActor(Actor):
"""Each instance of this actor represents a single process."""
@@ -83,9 +83,9 @@ async def get(self, rank_offset=0):

@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_objects(strategy_params, use_rdma):
async def test_objects(strategy_params, transport_type):
"""Test put/get on arbitrary object"""
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
set_transport_type(transport_type)

class ObjectActor(Actor):
"""Each instance of this actor represents a single process."""
@@ -147,9 +147,9 @@ def __eq__(self, other: object) -> bool:

@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_exists(strategy_params, use_rdma):
async def test_exists(strategy_params, transport_type):
"""Test the exists() API functionality"""
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
set_transport_type(transport_type)

class ExistsTestActor(Actor):
"""Actor for testing exists functionality."""
@@ -216,9 +216,9 @@ async def exists(self, key):

@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_delete(strategy_params, use_rdma):
async def test_delete(strategy_params, transport_type):
"""Test the delete() API functionality"""
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
set_transport_type(transport_type)

class DeleteTestActor(Actor):
"""Actor for testing delete functionality."""
8 changes: 4 additions & 4 deletions tests/test_tensor_slice.py
@@ -18,14 +18,14 @@
from torchstore.transport.pipe import TensorSlice
from torchstore.utils import spawn_actors

from .utils import DTensorActor, main, transport_plus_strategy_params
from .utils import DTensorActor, main, set_transport_type, transport_plus_strategy_params


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_get_tensor_slice(strategy_params, use_rdma):
async def test_get_tensor_slice(strategy_params, transport_type):
"""Test tensor slice API functionality"""
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"
set_transport_type(transport_type)

class TensorSlicePutActor(Actor):
"""Actor for putting tensors."""
@@ -204,7 +204,7 @@ async def get_tensor(self, key, tensor_slice_spec=None):
return await ts.get(key, tensor_slice_spec=tensor_slice_spec)

# Use LocalRankStrategy with 2 storage volumes (no RDMA, no parametrization)
os.environ["TORCHSTORE_RDMA_ENABLED"] = "0"
set_transport_type("none")
os.environ["LOCAL_RANK"] = "0" # Required by LocalRankStrategy

await ts.initialize(num_storage_volumes=2, strategy=ts.LocalRankStrategy())
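For reference, `transport_plus_strategy_params()` is what feeds both `strategy_params` and the new `transport_type` argument into every `@pytest.mark.parametrize(*...)` call above. Its implementation also lives in `tests/utils.py` and is not part of this diff; a rough sketch under stated assumptions (the strategy list and transport values are illustrative, and only `"none"` and `ts.LocalRankStrategy` actually appear in the hunks) could be:

```python
# Hypothetical sketch of tests/utils.py::transport_plus_strategy_params.
# The strategy list and transport names are illustrative; the real helper
# likely enumerates more strategies and the project's actual transport types.
import torchstore as ts


def transport_plus_strategy_params():
    # Each strategy_params entry is a (name, strategy) tuple, matching the
    # "_, strategy = strategy_params" unpacking used in the tests above.
    strategies = [("local_rank", ts.LocalRankStrategy())]
    transports = ["none", "rdma"]  # "rdma" is assumed; "none" appears in test_tensor_slice.py
    combos = [(s, t) for s in strategies for t in transports]
    # Star-unpacked into @pytest.mark.parametrize(argnames, argvalues).
    return "strategy_params,transport_type", combos
```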