Commit fc6f1dc

Sparsity Preserving DP-SGD in TF Privacy [4 of 5]

Add contribution count function for embedding layer. See
https://research.google/blog/sparsity-preserving-differentially-private-training/
for more details on the algorithm.

PiperOrigin-RevId: 656091009

1 parent 80802c2 commit fc6f1dc

File tree

6 files changed (+531, -2 lines)


tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD

Lines changed: 4 additions & 1 deletion

@@ -21,5 +21,8 @@ py_library(
 py_library(
     name = "layer_registry",
     srcs = ["layer_registry.py"],
-    deps = [":type_aliases"],
+    deps = [
+        ":type_aliases",
+        "//tensorflow_privacy/privacy/sparsity_preserving_noise/registry_functions:embedding",
+    ],
 )

tensorflow_privacy/privacy/sparsity_preserving_noise/layer_registry.py

Lines changed: 13 additions & 0 deletions

@@ -17,6 +17,7 @@
 
 import tensorflow as tf
 from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases
+from tensorflow_privacy.privacy.sparsity_preserving_noise.registry_functions import embedding
 
 
 # ==============================================================================
@@ -49,3 +50,15 @@ def insert(
     layer_key = hash(layer_class)
     self._layer_class_dict[layer_key] = layer_class
     self._registry[layer_key] = layer_registry_function
+
+
+# ==============================================================================
+# Main factory methods
+# ==============================================================================
+def make_default_layer_registry() -> LayerRegistry:
+  registry = LayerRegistry()
+  registry.insert(
+      tf.keras.layers.Embedding,
+      embedding.embedding_layer_contribution_histogram,
+  )
+  return registry
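
As a quick illustration of the new factory method, here is a minimal usage sketch. It assumes the imports below resolve, and it reaches into the internal `_registry` dict to fetch the registered function, since this diff only shows the `insert` method; any public lookup accessor is not shown here and the one used below is hypothetical.

    import tensorflow as tf
    from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry

    layer = tf.keras.layers.Embedding(input_dim=3, output_dim=4)
    layer.build((None,))  # materialize the single trainable embedding variable

    registry = layer_registry.make_default_layer_registry()
    # Hypothetical lookup via the internal dict populated by insert().
    registry_fn = registry._registry[hash(tf.keras.layers.Embedding)]

    # Maps the embedding variable's name to a contribution-count histogram fn.
    input_ids = tf.constant([[1, 2], [0, 0], [2, 0]])
    var_to_fn = registry_fn(layer, (input_ids,), {})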
tensorflow_privacy/privacy/sparsity_preserving_noise/registry_functions/BUILD

Lines changed: 15 additions & 0 deletions (new file)

package(default_visibility = ["//visibility:public"])

licenses(["notice"])

py_library(
    name = "embedding",
    srcs = ["embedding.py"],
    deps = ["//tensorflow_privacy/privacy/sparsity_preserving_noise:type_aliases"],
)

py_test(
    name = "embedding_test",
    srcs = ["embedding_test.py"],
    deps = [":embedding"],
)
tensorflow_privacy/privacy/sparsity_preserving_noise/registry_functions/embedding.py

Lines changed: 228 additions & 0 deletions (new file)

# Copyright 2024, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compute the contribution histogram for an embedding layer."""

from typing import Optional

import tensorflow as tf
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases


def embedding_layer_contribution_histogram(
    layer_instance: tf.keras.layers.Embedding,
    input_args: type_aliases.InputArgs,
    input_kwargs: type_aliases.InputKwargs,
    num_microbatches: Optional[tf.Tensor] = None,
) -> dict[str, type_aliases.ContributionCountHistogramFn]:
  """Registry function for `tf.keras.layers.Embedding`.

  Args:
    layer_instance: A `tf.keras.layers.Embedding` instance.
    input_args: A `tuple` containing the first part of `layer_instance` input.
      Specifically, `layer_instance(*input_args, **input_kwargs)` should return
      a valid output.
    input_kwargs: A `dict` containing the second part of `layer_instance`
      input. Specifically, `layer_instance(*input_args, **input_kwargs)` should
      return a valid output.
    num_microbatches: An optional numeric value or scalar `tf.Tensor`
      indicating whether and how the losses are grouped into microbatches. If
      not None, `num_microbatches` must divide the batch size.

  Returns:
    A dict mapping the name of the trainable variable to a function with
    signature `(tf.IndexedSlices) -> tf.SparseTensor`. The function takes a
    `tf.IndexedSlices` object representing the gradient for that variable and
    returns a `tf.SparseTensor` representing the normalized (so that each user
    contributes 1) contribution count histogram per user for each embedding
    vector.
  """
  if input_kwargs:
    raise ValueError("Embedding layer calls should not receive kwargs.")
  del input_kwargs  # Unused in embedding layer calls.
  if not input_args or len(input_args) != 1:
    raise ValueError("Only layer inputs of length 1 are permitted.")
  if hasattr(layer_instance, "sparse"):  # for backwards compatibility
    if layer_instance.sparse:
      raise NotImplementedError("Sparse output tensors are not supported.")
  if isinstance(input_args[0], tf.SparseTensor):
    raise NotImplementedError("Sparse input tensors are not supported.")

  # Disable experimental features.
  if hasattr(layer_instance, "_use_one_hot_matmul"):
    if layer_instance._use_one_hot_matmul:  # pylint: disable=protected-access
      raise NotImplementedError(
          "The experimental embedding feature "
          "'_use_one_hot_matmul' is not supported."
      )
  input_ids = tf.squeeze(tf.cast(*input_args, tf.int32))

  def count_contributions_fn(
      grad: type_aliases.SparseGradient,
  ) -> type_aliases.ContributionCountHistogram:
    return embedding_layer_contribution_histogram_fn(
        grad,
        input_ids,
        layer_instance.input_dim,
        num_microbatches,
    )

  if (
      not layer_instance.trainable_variables
      or len(layer_instance.trainable_variables) != 1
  ):
    raise ValueError(
        "Embedding layer must have exactly one trainable variable."
    )
  return {layer_instance.trainable_variables[0].name: count_contributions_fn}


def embedding_layer_contribution_histogram_fn(
    grad: type_aliases.SparseGradient,
    input_ids: tf.Tensor,
    vocab_size: Optional[tf.Tensor],
    num_microbatches: Optional[tf.Tensor] = None,
) -> type_aliases.ContributionCountHistogram:
  """Computes the normalized contribution count histogram for an embedding layer.

  NOTE: to help understand the code, we document in the function body what the
  expected intermediate variables are for the below running example:

    grad = None
    input_ids = [[1, 1, 2], [0], [2, 0]]
    vocab_size = 3
    num_microbatches = None

  For ease of reference, we also list these variables below:

    row_indices = [[0], [0], [0], [1], [2], [2]]
    flattened_indices = [[1], [1], [2], [0], [2], [0]]
    paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
    linearized_pair_indices = [1 1 2 3 8 6]
    contribution_counts_linearized_indices = [1 2 3 8 6]
    contribution_counts_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
    contribution_counts_values = [2 1 1 1 1]
    user_normalized_contribution_counts = tf.SparseTensor(
        indices=[[0, 1], [0, 2], [1, 0], [2, 0], [2, 2]],
        values=[0.67, 0.33, 1., 0.5, 0.5],
        shape=(3, 3),
    )
    contribution_histogram = tf.SparseTensor(
        indices=[[0], [1], [2]],
        values=[1.5, 0.67, 0.83],
        shape=(3,),
    )

  Args:
    grad: The gradient of the layer (unused for the embedding layer).
    input_ids: The input ids used to compute the embeddings.
    vocab_size: The vocabulary size of the embedding layer.
    num_microbatches: An optional numeric value or scalar `tf.Tensor`
      indicating whether and how the losses are grouped into microbatches. If
      not None, `num_microbatches` must divide the batch size.

  Returns:
    A `tf.SparseTensor` representing the normalized (so that each user
    contributes 1) contribution count histogram per user for each embedding
    vector.

  Raises:
    NotImplementedError: If `input_ids` is not a `tf.Tensor` or
      `tf.RaggedTensor`.
  """
  del grad  # unused.

  nrows = tf.shape(input_ids)[0]
  if isinstance(input_ids, tf.RaggedTensor):
    row_indices = tf.expand_dims(
        input_ids.merge_dims(1, -1).value_rowids(), axis=-1
    )
  elif isinstance(input_ids, tf.Tensor):
    ncols = tf.reduce_prod(tf.shape(input_ids)[1:])
    repeats = tf.repeat(ncols, nrows)
    row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1])
    row_indices = tf.cast(row_indices, tf.int64)
  else:
    raise NotImplementedError(
        "Cannot parse input_ids of type %s" % input_ids.__class__.__name__
    )

  if num_microbatches is not None:
    tf.debugging.assert_equal(
        nrows % num_microbatches,
        0,
        "num_microbatches must divide the batch size.",
    )
    microbatch_size = tf.cast(nrows / num_microbatches, tf.int64)
    nrows = num_microbatches
    row_indices = tf.cast(
        tf.math.floordiv(row_indices, microbatch_size), tf.int64
    )
  # NOTE: expected values for the running example above are
  # row_indices = [[0], [0], [0], [1], [2], [2]]

  flattened_indices = tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1)
  paired_indices = tf.concat(
      [tf.cast(row_indices, tf.int64), tf.cast(flattened_indices, tf.int64)],
      axis=1,
  )
  # NOTE: expected values for the running example above are
  # flattened_indices = [[1], [1], [2], [0], [2], [0]]
  # paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]

  transform = tf.cast(tf.stack([[vocab_size], [1]], axis=0), tf.int64)
  linearized_pair_indices = tf.reshape(
      tf.matmul(paired_indices, transform), (-1,)
  )
  contribution_counts_linearized_indices, _, contribution_counts_values = (
      tf.unique_with_counts(linearized_pair_indices)
  )
  contribution_counts_indices = tf.stack(
      [
          contribution_counts_linearized_indices // vocab_size,
          contribution_counts_linearized_indices % vocab_size,
      ],
      axis=1,
  )
  contribution_counts = tf.sparse.SparseTensor(
      contribution_counts_indices,
      contribution_counts_values,
      (nrows, vocab_size),
  )
  contribution_counts = tf.sparse.reorder(contribution_counts)
  # NOTE: expected values for the running example above are
  # linearized_pair_indices = [1 1 2 3 8 6]
  # contribution_counts_linearized_indices = [1 2 3 8 6]
  # contribution_counts_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]]
  # contribution_counts_values = [2 1 1 1 1]

  user_normalized_contribution_counts = (
      contribution_counts
      / tf.sparse.reduce_sum(contribution_counts, axis=-1, keepdims=True)
  )
  contribution_histogram = tf.sparse.reduce_sum(
      user_normalized_contribution_counts, axis=0, output_is_sparse=True
  )
  # NOTE: expected values for the running example above are
  # user_normalized_contribution_counts = tf.SparseTensor(
  #     indices=[[0, 1], [0, 2], [1, 0], [2, 0], [2, 2]],
  #     values=[0.67, 0.33, 1., 0.5, 0.5],
  #     shape=(3, 3),
  # )
  # contribution_histogram = tf.SparseTensor(
  #     indices=[[0], [1], [2]],
  #     values=[1.5, 0.67, 0.83],
  #     shape=(3,),
  # )

  return tf.sparse.reshape(contribution_histogram, (-1,))
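
To tie the pieces together, here is a minimal end-to-end sketch (assuming the module is importable as below) that runs a dense variant of the docstring's running example:

    import tensorflow as tf
    from tensorflow_privacy.privacy.sparsity_preserving_noise.registry_functions import embedding

    # 3 users with 2 embedding lookups each; vocabulary of size 3.
    input_ids = tf.constant([[1, 2], [0, 0], [2, 0]])
    hist = embedding.embedding_layer_contribution_histogram_fn(
        grad=None, input_ids=input_ids, vocab_size=3
    )
    # Each user contributes 1 in total: user 1's two lookups of id 0 are
    # normalized to a single contribution of 1.0, so the dense histogram
    # is [1.5, 0.5, 1.0].
    print(tf.sparse.to_dense(hist))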
