|
| 1 | +# Copyright 2024, The TensorFlow Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +"""Compute the contribution histogram for an embedding layer.""" |
| 15 | + |
| 16 | +from typing import Optional |
| 17 | +import tensorflow as tf |
| 18 | +from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases |
| 19 | + |
| 20 | + |
| 21 | +def embedding_layer_contribution_histogram( |
| 22 | + layer_instance: tf.keras.layers.Embedding, |
| 23 | + input_args: type_aliases.InputArgs, |
| 24 | + input_kwargs: type_aliases.InputKwargs, |
| 25 | + num_microbatches: Optional[tf.Tensor] = None, |
| 26 | +) -> dict[str, type_aliases.ContributionCountHistogramFn]: |
| 27 | + """Registry function for `tf.keras.layers.Embedding`. |
| 28 | +
|
| 29 | + Args: |
| 30 | + layer_instance: A `tf.keras.layers.Embedding` instance. |
| 31 | + input_args: A `tuple` containing the first part of `layer_instance` input. |
| 32 | + Specifically, `layer_instance(*inputs_args, **input_kwargs)` should return |
| 33 | + a valid output. |
| 34 | + input_kwargs: A `tuple` containing the second part of `layer_instance` |
| 35 | + input. Specifically, `layer_instance(*inputs_args, **input_kwargs)` should |
| 36 | + return a valid output. |
| 37 | + num_microbatches: An optional numeric value or scalar `tf.Tensor` for |
| 38 | + indicating whether and how the losses are grouped into microbatches. If |
| 39 | + not None, num_microbatches must divide the batch size. |
| 40 | +
|
| 41 | + Returns: |
| 42 | + A dict mapping the name of the trainable variable to a function with |
| 43 | + signature `(tf.IndexedSlices) -> tf.SparseTensor`. The function takes a |
| 44 | + `tf.IndexedSlices` object representing the gradient for that variable and |
| 45 | + returns a `tf.SparseTensor` representing the normalized (so that each user |
| 46 | + contributes 1) contribution counts histogram per user for each embedding |
| 47 | + vector. |
| 48 | + """ |
| 49 | + if input_kwargs: |
| 50 | + raise ValueError("Embedding layer calls should not receive kwargs.") |
| 51 | + del input_kwargs # Unused in embedding layer calls. |
| 52 | + if not input_args or len(input_args) != 1: |
| 53 | + raise ValueError("Only layer inputs of length 1 are permitted.") |
| 54 | + if hasattr(layer_instance, "sparse"): # for backwards compatibility |
| 55 | + if layer_instance.sparse: |
| 56 | + raise NotImplementedError("Sparse output tensors are not supported.") |
| 57 | + if isinstance(input_args[0], tf.SparseTensor): |
| 58 | + raise NotImplementedError("Sparse input tensors are not supported.") |
| 59 | + |
| 60 | + # Disable experimental features. |
| 61 | + if hasattr(layer_instance, "_use_one_hot_matmul"): |
| 62 | + if layer_instance._use_one_hot_matmul: # pylint: disable=protected-access |
| 63 | + raise NotImplementedError( |
| 64 | + "The experimental embedding feature " |
| 65 | + "'_use_one_hot_matmul' is not supported." |
| 66 | + ) |
| 67 | + input_ids = tf.squeeze(tf.cast(*input_args, tf.int32)) |
| 68 | + |
| 69 | + def count_contributions_fn( |
| 70 | + grad: type_aliases.SparseGradient, |
| 71 | + ) -> type_aliases.ContributionCountHistogram: |
| 72 | + return embedding_layer_contribution_histogram_fn( |
| 73 | + grad, |
| 74 | + input_ids, |
| 75 | + layer_instance.input_dim, |
| 76 | + num_microbatches, |
| 77 | + ) |
| 78 | + |
| 79 | + if ( |
| 80 | + not layer_instance.trainable_variables |
| 81 | + or len(layer_instance.trainable_variables) != 1 |
| 82 | + ): |
| 83 | + raise ValueError( |
| 84 | + "Embedding layer must have exactly one trainable variable." |
| 85 | + ) |
| 86 | + return {layer_instance.trainable_variables[0].name: count_contributions_fn} |
| 87 | + |
| 88 | + |
| 89 | +def embedding_layer_contribution_histogram_fn( |
| 90 | + grad: type_aliases.SparseGradient, |
| 91 | + input_ids: tf.Tensor, |
| 92 | + vocab_size: Optional[tf.Tensor], |
| 93 | + num_microbatches: Optional[tf.Tensor] = None, |
| 94 | +) -> type_aliases.ContributionCountHistogram: |
| 95 | + """Computes the normalized contribution counts histogram for embedding layer. |
| 96 | +
|
| 97 | + NOTE: to help understand the code, we document in the function body what the |
| 98 | + expected intermediate variables are for the below running example: |
| 99 | +
|
| 100 | + grad = None |
| 101 | + input_ids = [[1, 1, 2], [0], [2, 0]] |
| 102 | + vocab_size = 3 |
| 103 | + num_microbatches = None |
| 104 | +
|
| 105 | + For ease of reference, we also list these variables below: |
| 106 | +
|
| 107 | + row_indices = [[0], [0], [0], [1], [2], [2]] |
| 108 | + flattened_indices = [[1], [1], [2], [0], [2], [0]] |
| 109 | + paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]] |
| 110 | + linearized_pair_indices = [1 1 2 3 8 6] |
| 111 | + contribution_counts_linearized_indices = [1 2 3 8 6] |
| 112 | + contribution_counts_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]] |
| 113 | + contribution_counts_values = [2 1 1 1 1] |
| 114 | + user_normalized_contribution_counts = tf.SparseTensor( |
| 115 | + indices=[[0, 1], [0, 2], [1, 0], [2, 0], [2, 2]], |
| 116 | + values=[0.67, 0.33, 1., 0.5, 0.5,] |
| 117 | + shape=(3, 3) |
| 118 | + ) |
| 119 | + contribution_histogram = tf.SparseTensor( |
| 120 | + indices=[[0], [1], [2]], |
| 121 | + values=[1.5, 0.67, 0.83], |
| 122 | + shape=(3,) |
| 123 | + ) |
| 124 | +
|
| 125 | +
|
| 126 | + Args: |
| 127 | + grad: The gradient of the layer. (unused for embedding layer) |
| 128 | + input_ids: The input ids used to compute the embeddings. |
| 129 | + vocab_size: The vocabulary size of the embedding layer. |
| 130 | + num_microbatches: An optional numeric value or scalar `tf.Tensor` for |
| 131 | + indicating whether and how the losses are grouped into microbatches. If |
| 132 | + not None, num_microbatches must divide the batch size. |
| 133 | +
|
| 134 | + Returns: |
| 135 | + A `tf.SparseTensor` representing the normalized (so that each user |
| 136 | + contributes 1) contribution counts histogram per user for each embedding |
| 137 | + vector. |
| 138 | +
|
| 139 | + Raises: |
| 140 | + NotImplementedError: If the input_ids is not a `tf.Tensor` or |
| 141 | + `tf.RaggedTensor`. |
| 142 | + """ |
| 143 | + del grad # unused. |
| 144 | + |
| 145 | + nrows = tf.shape(input_ids)[0] |
| 146 | + if isinstance(input_ids, tf.RaggedTensor): |
| 147 | + row_indices = tf.expand_dims( |
| 148 | + input_ids.merge_dims(1, -1).value_rowids(), axis=-1 |
| 149 | + ) |
| 150 | + elif isinstance(input_ids, tf.Tensor): |
| 151 | + ncols = tf.reduce_prod(tf.shape(input_ids)[1:]) |
| 152 | + repeats = tf.repeat(ncols, nrows) |
| 153 | + row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1]) |
| 154 | + row_indices = tf.cast(row_indices, tf.int64) |
| 155 | + else: |
| 156 | + raise NotImplementedError( |
| 157 | + "Cannot parse input_ids of type %s" % input_ids.__class__.__name__ |
| 158 | + ) |
| 159 | + |
| 160 | + if num_microbatches is not None: |
| 161 | + tf.debugging.assert_equal( |
| 162 | + nrows % num_microbatches, |
| 163 | + 0, |
| 164 | + "num_microbatches must divide the batch size.", |
| 165 | + ) |
| 166 | + microbatch_size = tf.cast(nrows / num_microbatches, tf.int64) |
| 167 | + nrows = num_microbatches |
| 168 | + row_indices = tf.cast( |
| 169 | + tf.math.floordiv(row_indices, microbatch_size), tf.int64 |
| 170 | + ) |
| 171 | + # NOTE: expected values for the running example above are |
| 172 | + # row_indices = [[0], [0], [0], [1], [2], [2]] |
| 173 | + |
| 174 | + flattened_indices = tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1) |
| 175 | + paired_indices = tf.concat( |
| 176 | + [tf.cast(row_indices, tf.int64), tf.cast(flattened_indices, tf.int64)], |
| 177 | + axis=1, |
| 178 | + ) |
| 179 | + # NOTE: expected values for the running example above are |
| 180 | + # flattened_indices = [[1], [1], [2], [0], [2], [0]] |
| 181 | + # paired_indices = [[0, 1], [0, 1], [0, 2], [1, 0], [2, 2], [2, 0]] |
| 182 | + |
| 183 | + transform = tf.cast(tf.stack([[vocab_size], [1]], axis=0), tf.int64) |
| 184 | + linearized_pair_indices = tf.reshape( |
| 185 | + tf.matmul(paired_indices, transform), (-1,) |
| 186 | + ) |
| 187 | + contribution_counts_linearized_indices, _, contribution_counts_values = ( |
| 188 | + tf.unique_with_counts(linearized_pair_indices) |
| 189 | + ) |
| 190 | + contribution_counts_indices = tf.stack( |
| 191 | + [ |
| 192 | + contribution_counts_linearized_indices // vocab_size, |
| 193 | + contribution_counts_linearized_indices % vocab_size, |
| 194 | + ], |
| 195 | + axis=1, |
| 196 | + ) |
| 197 | + contribution_counts = tf.sparse.SparseTensor( |
| 198 | + contribution_counts_indices, |
| 199 | + contribution_counts_values, |
| 200 | + (nrows, vocab_size), |
| 201 | + ) |
| 202 | + contribution_counts = tf.sparse.reorder(contribution_counts) |
| 203 | + # NOTE: expected values for the running example above are |
| 204 | + # linearized_pair_indices = [1 1 2 3 8 6] |
| 205 | + # contribution_counts_linearized_indices = [1 2 3 8 6] |
| 206 | + # contribution_counts_indices = [[0, 1], [0, 2], [1, 0], [2, 2], [2, 0]] |
| 207 | + # contribution_counts_values = [2 1 1 1 1] |
| 208 | + |
| 209 | + user_normalized_contribution_counts = ( |
| 210 | + contribution_counts |
| 211 | + / tf.sparse.reduce_sum(contribution_counts, axis=-1, keepdims=True) |
| 212 | + ) |
| 213 | + contribution_histogram = tf.sparse.reduce_sum( |
| 214 | + user_normalized_contribution_counts, axis=0, output_is_sparse=True |
| 215 | + ) |
| 216 | + # NOTE: expected values for the running example above are |
| 217 | + # user_normalized_contribution_counts = tf.SparseTensor( |
| 218 | + # indices=[[0, 1], [0, 2], [1, 0], [2, 0], [2, 2]], |
| 219 | + # values=[0.67, 0.33, 1., 0.5, 0.5,] |
| 220 | + # shape=(3, 3) |
| 221 | + # ) |
| 222 | + # contribution_histogram = tf.SparseTensor( |
| 223 | + # indices=[[0], [1], [2]], |
| 224 | + # values=[1.5, 0.67, 0.83], |
| 225 | + # shape=(3,) |
| 226 | + # ) |
| 227 | + |
| 228 | + return tf.sparse.reshape(contribution_histogram, (-1,)) |
0 commit comments