Commit f7af51d

Alex Sablayrolles committed

[ahem] Committing embedding bag tests

1 parent a89c594 · commit f7af51d

1 file changed: +79 −0 lines changed

1 file changed

+79
-0
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hypothesis.strategies as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from hypothesis import given, settings

from .common import GradSampleHooks_test


class Embedding_bag_test(GradSampleHooks_test):
    @given(
        N=st.integers(4, 8),
        sz=st.integers(3, 7),
        V=st.integers(10, 32),
        D=st.integers(10, 17),
        mode=st.sampled_from(["sum", "mean"]),
    )
    @settings(deadline=10000)
    def test_input_across_dims(
        self,
        N: int,
        sz: int,
        V: int,
        D: int,
        mode: str,
    ):
        emb = nn.EmbeddingBag(num_embeddings=V, embedding_dim=D, mode=mode)

        # Each of the N bags holds between 1 and sz distinct token ids;
        # offsets[i] marks where bag i starts in the flattened input.
        sizes = torch.randint(low=1, high=sz + 1, size=(N,))
        offsets = torch.LongTensor([0] + torch.cumsum(sizes, dim=0).tolist()[:-1])
        input = torch.cat([torch.randperm(V)[:size] for size in sizes], dim=0)

        # Draft of a manual microbatch check (one backward pass per bag):
        # target = torch.randn(N, D)
        # grad_microbatches = []
        # for i in range(N):
        #     emb.zero_grad()
        #     output = emb(input[offsets[i] : offsets[i] + sizes[i]], torch.LongTensor([0]))
        #     loss = F.mse_loss(output, target[i].unsqueeze(0))
        #     loss.backward()
        #     grad_microbatches.append(emb.weight.grad.clone())

        def chunk_method(x):
            # Split the flat (input, offsets) batch back into single-bag
            # microbatches, each carrying a trivial offset of [0].
            input, offsets = x
            for i_offset, offset in enumerate(offsets):
                if i_offset < len(offsets) - 1:
                    next_offset = offsets[i_offset + 1]
                else:
                    next_offset = len(input)
                yield (input[offset:next_offset], torch.LongTensor([0]))

        self.run_test((input, offsets), emb, chunk_method=chunk_method)
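
The commented-out draft in the diff hints at the check the harness performs: run each bag through the module as its own microbatch and compare gradients against the full batch. Below is a minimal standalone sketch of that idea, outside the opacus harness; the fixed dimensions, the MSE loss, and the mean-of-gradients comparison are illustrative assumptions, not GradSampleHooks_test's actual procedure.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, sz, V, D = 4, 3, 10, 5  # small fixed dims, illustrative only
emb = nn.EmbeddingBag(num_embeddings=V, embedding_dim=D, mode="sum")

# Flat batch of N bags, built the same way as in the test above.
sizes = torch.randint(low=1, high=sz + 1, size=(N,))
offsets = torch.LongTensor([0] + torch.cumsum(sizes, dim=0).tolist()[:-1])
input = torch.cat([torch.randperm(V)[:size] for size in sizes], dim=0)
target = torch.randn(N, D)

# Microbatch pass: one backward per bag, collecting per-example gradients.
grad_microbatches = []
for i in range(N):
    emb.zero_grad()
    output = emb(input[offsets[i] : offsets[i] + sizes[i]], torch.LongTensor([0]))
    loss = F.mse_loss(output, target[i].unsqueeze(0))
    loss.backward()
    grad_microbatches.append(emb.weight.grad.clone())

# Full-batch pass: with mean-reduced MSE, the batch gradient should equal
# the mean of the per-example gradients.
emb.zero_grad()
loss = F.mse_loss(emb(input, offsets), target)
loss.backward()
assert torch.allclose(
    emb.weight.grad, torch.stack(grad_microbatches).mean(dim=0), atol=1e-6
)

The test itself delegates this comparison to GradSampleHooks_test.run_test, with chunk_method supplying the single-bag inputs against which the grad-sample hooks are checked.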
