feat: add device management for CUDA and CPU support across models

Allows systems without CUDA to fall back to the CPU.

provos committed Nov 28, 2025
commit 351ce7bb52fcfa734c110f2297849c19bd5fb2a4
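The commit applies one pattern throughout: wherever a device was hardcoded as "cuda", probe torch.cuda.is_available() and fall back to the CPU. A minimal self-contained sketch of that pattern (the pick_device helper is illustrative, not part of this commit):

import torch

def pick_device(requested=None):
    # Honor an explicit request; otherwise prefer CUDA, falling back to CPU.
    if requested is not None:
        return torch.device(requested)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

device = pick_device()  # replaces hardcoded device="cuda" call sites
buffer = torch.zeros((1, 1, 16, 16), device=device)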
3 changes: 2 additions & 1 deletion sam3/model/decoder.py
@@ -277,8 +277,9 @@ def __init__(
 
         if resolution is not None and stride is not None:
             feat_size = resolution // stride
+            device = "cuda" if torch.cuda.is_available() else "cpu"
             coords_h, coords_w = self._get_coords(
-                feat_size, feat_size, device="cuda"
+                feat_size, feat_size, device=device
             )
             self.compilable_cord_cache = (coords_h, coords_w)
             self.compilable_stored_size = (feat_size, feat_size)
124 changes: 79 additions & 45 deletions sam3/model/io_utils.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion sam3/model/position_encoding.py
@@ -43,8 +43,9 @@ def __init__(
                 (precompute_resolution // 16, precompute_resolution // 16),
                 (precompute_resolution // 32, precompute_resolution // 32),
             ]
+            device = "cuda" if torch.cuda.is_available() else "cpu"
             for size in precompute_sizes:
-                tensors = torch.zeros((1, 1) + size, device="cuda")
+                tensors = torch.zeros((1, 1) + size, device=device)
                 self.forward(tensors)
                 # further clone and detach it in the cache (just to be safe)
                 self.cache[size] = self.cache[size].clone().detach()
4 changes: 3 additions & 1 deletion sam3/model/sam3_image_processor.py
@@ -14,9 +14,11 @@
 class Sam3Processor:
     """ """
 
-    def __init__(self, model, resolution=1008, device="cuda", confidence_threshold=0.5):
+    def __init__(self, model, resolution=1008, device=None, confidence_threshold=0.5):
         self.model = model
         self.resolution = resolution
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
         self.device = device
         self.transform = v2.Compose(
             [
13 changes: 12 additions & 1 deletion sam3/model/sam3_tracker_utils.py
@@ -5,7 +5,14 @@
 import torch.nn.functional as F
 from numpy.typing import NDArray
 
-from sam3.model.edt import edt_triton
+# Triton is only available on CUDA (not Apple Silicon/MPS)
+try:
+    from sam3.model.edt import edt_triton
+
+    _HAS_TRITON = True
+except ImportError:
+    _HAS_TRITON = False
+    edt_triton = None
 
 
 def sample_box_points(
@@ -148,6 +155,10 @@ def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
     - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
     - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks
     """
+    # Fall back to slow (OpenCV-based) implementation if Triton is not available
+    if not _HAS_TRITON or not gt_masks.is_cuda:
+        return sample_one_point_from_error_center_slow(gt_masks, pred_masks, padding)
+
     if pred_masks is None:
         pred_masks = torch.zeros_like(gt_masks)
     assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
29 changes: 17 additions & 12 deletions sam3/model/sam3_tracking_predictor.py
@@ -46,8 +46,11 @@ def __init__(
         self.max_point_num_in_prompt_enc = max_point_num_in_prompt_enc
         self.non_overlap_masks_for_output = non_overlap_masks_for_output
 
-        self.bf16_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
-        self.bf16_context.__enter__()  # keep using for the entire model process
+        if torch.cuda.is_available():
+            self.bf16_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+            self.bf16_context.__enter__()  # keep using for the entire model process
+        else:
+            self.bf16_context = None
 
         self.iter_use_prev_mask_pred = True
         self.add_all_frames_to_correct_as_cond = True
@@ -75,7 +78,7 @@ def init_state(
         # and from 24 to 21 when tracking two objects)
         inference_state["offload_state_to_cpu"] = offload_state_to_cpu
         inference_state["device"] = self.device
-        if offload_state_to_cpu:
+        if offload_state_to_cpu or not torch.cuda.is_available():
            inference_state["storage_device"] = torch.device("cpu")
         else:
             inference_state["storage_device"] = torch.device("cuda")
@@ -300,7 +303,7 @@ def add_new_points_or_box(
                 prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
 
         if prev_out is not None and prev_out["pred_masks"] is not None:
-            prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
+            prev_sam_mask_logits = prev_out["pred_masks"].to(inference_state["device"], non_blocking=torch.cuda.is_available())
             # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
             prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
             current_out, _ = self._run_single_frame_inference(
@@ -469,7 +472,7 @@ def _get_orig_video_res_output(self, inference_state, any_res_masks):
         device = inference_state["device"]
         video_H = inference_state["video_height"]
         video_W = inference_state["video_width"]
-        any_res_masks = any_res_masks.to(device, non_blocking=True)
+        any_res_masks = any_res_masks.to(device, non_blocking=torch.cuda.is_available())
         if any_res_masks.shape[-2:] == (video_H, video_W):
             video_res_masks = any_res_masks
         else:
@@ -609,7 +612,7 @@ def _consolidate_temp_output_across_obj(
         if run_mem_encoder:
             device = inference_state["device"]
             high_res_masks = torch.nn.functional.interpolate(
-                consolidated_out["pred_masks"].to(device, non_blocking=True),
+                consolidated_out["pred_masks"].to(device, non_blocking=torch.cuda.is_available()),
                 size=(self.image_size, self.image_size),
                 mode="bilinear",
                 align_corners=False,
@@ -1023,7 +1026,7 @@ def _get_image_feature(self, inference_state, frame_idx, batch_size):
            )
        else:
            # Cache miss -- we will run inference on a single image
-           image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
+           image = inference_state["images"][frame_idx].to(inference_state["device"]).float().unsqueeze(0)
            backbone_out = self.forward_image(image)
            # Cache the most recent frame's feature (for repeated interactions with
            # a frame; we can use an LRU cache for more frames in the future).
@@ -1095,10 +1098,11 @@ def _run_single_frame_inference(
         storage_device = inference_state["storage_device"]
         maskmem_features = current_out["maskmem_features"]
         if maskmem_features is not None:
-            maskmem_features = maskmem_features.to(torch.bfloat16)
-            maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+            if torch.cuda.is_available():
+                maskmem_features = maskmem_features.to(torch.bfloat16)
+            maskmem_features = maskmem_features.to(storage_device, non_blocking=torch.cuda.is_available())
         pred_masks_gpu = current_out["pred_masks"]
-        pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
+        pred_masks = pred_masks_gpu.to(storage_device, non_blocking=torch.cuda.is_available())
         # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
         maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
         # object pointer is a small tensor, so we always keep it on GPU memory for fast access
@@ -1146,8 +1150,9 @@ def _run_memory_encoder(
 
         # optionally offload the output to CPU memory to save GPU space
         storage_device = inference_state["storage_device"]
-        maskmem_features = maskmem_features.to(torch.bfloat16)
-        maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+        if torch.cuda.is_available():
+            maskmem_features = maskmem_features.to(torch.bfloat16)
+        maskmem_features = maskmem_features.to(storage_device, non_blocking=torch.cuda.is_available())
         # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
         maskmem_pos_enc = self._get_maskmem_pos_enc(
             inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
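A recurring substitution in this file is non_blocking=True → non_blocking=torch.cuda.is_available(). Asynchronous copies only pay off for CUDA transfers from pinned host memory; on CPU the flag is a no-op, and on non-CUDA backends such as MPS it has reportedly caused subtle issues in some PyTorch versions, so gating it on CUDA availability is the conservative choice. A standalone sketch of the idiom (the to_storage helper is illustrative, not part of this commit):

import torch

def to_storage(t, storage_device):
    # Request an async copy only when CUDA is actually involved;
    # elsewhere a plain synchronous copy is the safe default.
    return t.to(storage_device, non_blocking=torch.cuda.is_available())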
38 changes: 26 additions & 12 deletions sam3/model/sam3_video_predictor.py
@@ -39,6 +39,14 @@ def __init__(
         self.video_loader_type = video_loader_type
         from sam3.model_builder import build_sam3_video_model
 
+        # Determine device
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        else:
+            self.device = torch.device("cpu")
+
+        logger.info(f"Sam3VideoPredictor using device: {self.device}")
+
         self.model = (
             build_sam3_video_model(
                 checkpoint_path=checkpoint_path,
@@ -48,7 +56,7 @@
                 strict_state_dict_loading=strict_state_dict_loading,
                 apply_temporal_disambiguation=apply_temporal_disambiguation,
             )
-            .cuda()
+            .to(self.device)
             .eval()
         )
 
@@ -265,21 +273,27 @@ def _get_session_stats(self):
             f"'{session_id}' ({session['state']['num_frames']} frames)"
             for session_id, session in self._ALL_INFERENCE_STATES.items()
         ]
-        session_stats_str = (
-            f"live sessions: [{', '.join(live_session_strs)}], GPU memory: "
-            f"{torch.cuda.memory_allocated() // 1024**2} MiB used and "
-            f"{torch.cuda.memory_reserved() // 1024**2} MiB reserved"
-            f" (max over time: {torch.cuda.max_memory_allocated() // 1024**2} MiB used "
-            f"and {torch.cuda.max_memory_reserved() // 1024**2} MiB reserved)"
-        )
+        if torch.cuda.is_available():
+            mem_stats = (
+                f"GPU memory: {torch.cuda.memory_allocated() // 1024**2} MiB used and "
+                f"{torch.cuda.memory_reserved() // 1024**2} MiB reserved"
+                f" (max over time: {torch.cuda.max_memory_allocated() // 1024**2} MiB used "
+                f"and {torch.cuda.max_memory_reserved() // 1024**2} MiB reserved)"
+            )
+        else:
+            mem_stats = "Running on CPU"
+        session_stats_str = f"live sessions: [{', '.join(live_session_strs)}], {mem_stats}"
         return session_stats_str
 
     def _get_torch_and_gpu_properties(self):
         """Get a string for PyTorch and GPU properties (for logging and debugging)."""
-        torch_and_gpu_str = (
-            f"torch: {torch.__version__} with CUDA arch {torch.cuda.get_arch_list()}, "
-            f"GPU device: {torch.cuda.get_device_properties(torch.cuda.current_device())}"
-        )
+        if torch.cuda.is_available():
+            torch_and_gpu_str = (
+                f"torch: {torch.__version__} with CUDA arch {torch.cuda.get_arch_list()}, "
+                f"GPU device: {torch.cuda.get_device_properties(torch.cuda.current_device())}"
+            )
+        else:
+            torch_and_gpu_str = f"torch: {torch.__version__} (CPU mode)"
         return torch_and_gpu_str
 
     def shutdown(self):
2 changes: 1 addition & 1 deletion sam3/model/utils/sam2_utils.py
@@ -85,7 +85,7 @@ def __getitem__(self, index):
         img -= self.img_mean
         img /= self.img_std
         if not self.offload_video_to_cpu:
-            img = img.to(self.compute_device, non_blocking=True)
+            img = img.to(self.compute_device, non_blocking=torch.cuda.is_available())
         self.images[index] = img
         return img
 
8 changes: 6 additions & 2 deletions sam3/model/vl_combiner.py
@@ -119,8 +119,10 @@ def _forward_image_no_act_ckpt(self, samples):
         return output
 
     def forward_text(
-        self, captions, input_boxes=None, additional_text=None, device="cuda"
+        self, captions, input_boxes=None, additional_text=None, device=None
     ):
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
         return activation_ckpt_wrapper(self._forward_text_no_ack_ckpt)(
             captions=captions,
             input_boxes=input_boxes,
@@ -134,8 +136,10 @@ def _forward_text_no_ack_ckpt(
         captions,
         input_boxes=None,
         additional_text=None,
-        device="cuda",
+        device=None,
     ):
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
         output = {}
 
         # Forward through text_encoder
9 changes: 7 additions & 2 deletions sam3/model_builder.py
@@ -34,7 +34,7 @@
 from sam3.model.sam3_image import Sam3Image, Sam3ImageOnVideoMultiGPU
 from sam3.model.sam3_tracking_predictor import Sam3TrackerPredictor
 from sam3.model.sam3_video_inference import Sam3VideoInferenceWithInstanceInteractivity
-from sam3.model.sam3_video_predictor import Sam3VideoPredictorMultiGPU
+from sam3.model.sam3_video_predictor import Sam3VideoPredictor, Sam3VideoPredictorMultiGPU
 from sam3.model.text_encoder_ve import VETextEncoder
 from sam3.model.tokenizer_ve import SimpleTokenizer
 from sam3.model.vitdet import ViT
@@ -547,8 +547,10 @@ def _load_checkpoint(model, checkpoint_path):
 
 def _setup_device_and_mode(model, device, eval_mode):
     """Setup model device and evaluation mode."""
-    if device == "cuda":
+    if device == "cuda" and torch.cuda.is_available():
         model = model.cuda()
+    elif device != "cpu":
+        model = model.to(device)
     if eval_mode:
         model.eval()
     return model
@@ -788,6 +790,9 @@ def build_sam3_video_model(
 
 
 def build_sam3_video_predictor(*model_args, gpus_to_use=None, **model_kwargs):
+    # Use single-device predictor on CPU, multi-GPU predictor only when CUDA is available
+    if not torch.cuda.is_available():
+        return Sam3VideoPredictor(*model_args, **model_kwargs)
     return Sam3VideoPredictorMultiGPU(
         *model_args, gpus_to_use=gpus_to_use, **model_kwargs
     )
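With this dispatch in place, CPU-only machines transparently get the single-device predictor. A hypothetical usage sketch (the checkpoint path is illustrative; keyword forwarding follows the predictor constructors shown above):

from sam3.model_builder import build_sam3_video_predictor

# On a CUDA host this returns Sam3VideoPredictorMultiGPU; without CUDA it
# now returns Sam3VideoPredictor instead of failing on .cuda() calls.
predictor = build_sam3_video_predictor(checkpoint_path="./checkpoints/sam3.pt")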