33# pyre-strict
44import warnings
55from enum import Enum
6- from typing import Any , Callable , cast , Dict , Iterable , List , Optional , Tuple , Union
6+ from typing import (
7+ Any ,
8+ Callable ,
9+ cast ,
10+ Dict ,
11+ Iterable ,
12+ List ,
13+ Optional ,
14+ Sequence ,
15+ Tuple ,
16+ Union ,
17+ )
718
819import matplotlib
920
@@ -74,8 +85,7 @@ def _cumulative_sum_threshold(
7485 )
7586 sorted_vals = np .sort (values .flatten ())
7687 cum_sums = np .cumsum (sorted_vals )
77- threshold_id = np .where (cum_sums >= cum_sums [- 1 ] * 0.01 * percentile )[0 ][0 ]
78- # pyre-fixme[7]: Expected `float` but got `ndarray[typing.Any, dtype[typing.Any]]`.
88+ threshold_id : int = np .where (cum_sums >= cum_sums [- 1 ] * 0.01 * percentile )[0 ][0 ]
7989 return sorted_vals [threshold_id ]
8090
8191
@@ -959,7 +969,7 @@ def __init__(
959969 self .convergence_score : float = convergence_score
960970
961971
962- def _get_color (attr : int ) -> str :
972+ def _get_color (attr : float ) -> str :
963973 # clip values to prevent CSS errors (Values should be from [-1,1])
964974 attr = max (- 1 , min (1 , attr ))
965975 if attr > 0 :
@@ -973,8 +983,7 @@ def _get_color(attr: int) -> str:
973983 return "hsl({}, {}%, {}%)" .format (hue , sat , lig )
974984
975985
976- # pyre-fixme[2]: Parameter must be annotated.
977- def format_classname (classname ) -> str :
986+ def format_classname (classname : Union [str , int ]) -> str :
978987 return '<td><text style="padding-right:2em"><b>{}</b></text></td>' .format (classname )
979988
980989
@@ -984,19 +993,24 @@ def format_special_tokens(token: str) -> str:
984993 return token
985994
986995
987- # pyre-fixme[2]: Parameter must be annotated.
988- def format_tooltip (item , text ) -> str :
996+ def format_tooltip (item : str , text : str ) -> str :
989997 return '<div class="tooltip">{item}\
990998 <span class="tooltiptext">{text}</span>\
991999 </div>' .format (
9921000 item = item , text = text
9931001 )
9941002
9951003
996- # pyre-fixme[2]: Parameter must be annotated.
997- def format_word_importances (words , importances ) -> str :
1004+ def format_word_importances (
1005+ words : Sequence [str ],
1006+ importances : Union [Sequence [float ], npt .NDArray [np .number ], Tensor ],
1007+ ) -> str :
9981008 if importances is None or len (importances ) == 0 :
9991009 return "<td></td>"
1010+ if isinstance (importances , np .ndarray ) or isinstance (importances , Tensor ):
1011+ assert len (importances .shape ) == 1 , "Expected 1D array, got {}" .format (
1012+ importances .shape
1013+ )
10001014 assert len (words ) <= len (importances )
10011015 tags = ["<td>" ]
10021016 for word , importance in zip (words , importances [: len (words )]):
0 commit comments