Skip to content

Commit b21e658

Browse files
khotilovtqchen
authored andcommitted
[R-package] JSON dump format and a couple of bugfixes (#1855)
* [R-package] JSON tree dump interface * [R-package] precision bugfix in xgb.attributes * [R-package] bugfix for cb.early.stop called from xgb.cv * [R-package] a bit more clarity on labels checking in xgb.cv * [R-package] test JSON dump for gblinear as well * whitespace lint
1 parent 0268ded commit b21e658

File tree

10 files changed

+72
-22
lines changed

10 files changed

+72
-22
lines changed

R-package/R/callbacks.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ cb.reset.parameters <- function(new_params) {
229229
xgb.parameters(env$bst$handle) <- pars
230230
} else {
231231
for (fd in env$bst_folds)
232-
xgb.parameters(fd$bst$handle) <- pars
232+
xgb.parameters(fd$bst) <- pars
233233
}
234234
}
235235
attr(callback, 'is_pre_iteration') <- TRUE

R-package/R/xgb.Booster.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ xgb.attributes <- function(object) {
339339
# Q: should we warn a user about non-scalar elements?
340340
a <- lapply(a, function(x) {
341341
if (is.null(x)) return(NULL)
342-
if (is.numeric(value[1])) {
342+
if (is.numeric(x[1])) {
343343
format(x[1], digits = 17)
344344
} else {
345345
as.character(x[1])

R-package/R/xgb.cv.R

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
#'
1717
#' See \code{\link{xgb.train}} for further details.
1818
#' See also demo/ for walkthrough example in R.
19-
#' @param data takes an \code{xgb.DMatrix} or \code{Matrix} as the input.
19+
#' @param data takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input.
2020
#' @param nrounds the max number of iterations
2121
#' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
22-
#' @param label vector of response values. Should be provided only when data is \code{DMatrix}.
22+
#' @param label vector of response values. Should be provided only when data is an R-matrix.
2323
#' @param missing is only used when input is a dense matrix. By default is set to NA, which means
2424
#' that NA values should be considered as 'missing' by the algorithm.
2525
#' Sometimes, 0 or other extreme value might be used to represent missing values.
@@ -129,10 +129,9 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
129129
#if (is.null(params[['eval_metric']]) && is.null(feval))
130130
# stop("Either 'eval_metric' or 'feval' must be provided for CV")
131131

132-
# Labels
133-
if (class(data) == 'xgb.DMatrix')
134-
labels <- getinfo(data, 'label')
135-
if (is.null(labels))
132+
# Check the labels
133+
if ( (class(data) == 'xgb.DMatrix' && is.null(getinfo(data, 'label'))) ||
134+
(class(data) != 'xgb.DMatrix' && is.null(label)))
136135
stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
137136

138137
# CV folds

R-package/R/xgb.dump.R

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#' When this option is on, the model dump comes with two additional statistics:
1515
#' gain is the approximate loss function gain we get in each split;
1616
#' cover is the sum of second order gradient in each node.
17+
#' @param dump_fomat either 'text' or 'json' format could be specified.
1718
#' @param ... currently not used
1819
#'
1920
#' @return
@@ -30,10 +31,15 @@
3031
#' xgb.dump(bst, 'xgb.model.dump', with_stats = TRUE)
3132
#'
3233
#' # print the model without saving it to a file
33-
#' print(xgb.dump(bst))
34+
#' print(xgb.dump(bst, with_stats = TRUE))
35+
#'
36+
#' # print in JSON format:
37+
#' cat(xgb.dump(bst, with_stats = TRUE, dump_format='json'))
38+
#'
3439
#' @export
35-
xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with_stats=FALSE, ...) {
40+
xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with_stats=FALSE, dump_format = c("text", "json"), ...) {
3641
check.deprecation(...)
42+
dump_format <- match.arg(dump_format)
3743
if (class(model) != "xgb.Booster")
3844
stop("model: argument must be of type xgb.Booster")
3945
if (!(class(fname) %in% c("character", "NULL") && length(fname) <= 1))
@@ -42,12 +48,15 @@ xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with_stats=FALSE, ..
4248
stop("fmap: argument must be of type character (when provided)")
4349

4450
model <- xgb.Booster.check(model)
45-
model_dump <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with_stats), PACKAGE = "xgboost")
51+
model_dump <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with_stats),
52+
as.character(dump_format), PACKAGE = "xgboost")
4653

4754
if (is.null(fname))
4855
model_dump <- stri_replace_all_regex(model_dump, '\t', '')
4956

50-
model_dump <- unlist(stri_split_regex(model_dump, '\n'))
57+
if (dump_format == "text")
58+
model_dump <- unlist(stri_split_regex(model_dump, '\n'))
59+
5160
model_dump <- grep('^\\s*$', model_dump, invert = TRUE, value = TRUE)
5261

5362
if (is.null(fname)) {

R-package/man/xgb.cv.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

R-package/man/xgb.dump.Rd

Lines changed: 9 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

R-package/src/xgboost_R.cc

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -350,20 +350,37 @@ SEXP XGBoosterModelToRaw_R(SEXP handle) {
350350
return ret;
351351
}
352352

353-
SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats) {
353+
SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_format) {
354354
SEXP out;
355355
R_API_BEGIN();
356356
bst_ulong olen;
357357
const char **res;
358-
CHECK_CALL(XGBoosterDumpModel(R_ExternalPtrAddr(handle),
358+
const char *fmt = CHAR(asChar(dump_format));
359+
CHECK_CALL(XGBoosterDumpModelEx(R_ExternalPtrAddr(handle),
359360
CHAR(asChar(fmap)),
360361
asInteger(with_stats),
362+
fmt,
361363
&olen, &res));
362364
out = PROTECT(allocVector(STRSXP, olen));
363-
for (size_t i = 0; i < olen; ++i) {
365+
if (!strcmp("json", fmt)) {
364366
std::stringstream stream;
365-
stream << "booster[" << i <<"]\n" << res[i];
366-
SET_STRING_ELT(out, i, mkChar(stream.str().c_str()));
367+
stream << "[\n";
368+
for (size_t i = 0; i < olen; ++i) {
369+
stream << res[i];
370+
if (i < olen - 1) {
371+
stream << ",\n";
372+
} else {
373+
stream << "\n";
374+
}
375+
}
376+
stream << "]";
377+
SET_STRING_ELT(out, 0, mkChar(stream.str().c_str()));
378+
} else {
379+
for (size_t i = 0; i < olen; ++i) {
380+
std::stringstream stream;
381+
stream << "booster[" << i <<"]\n" << res[i];
382+
SET_STRING_ELT(out, i, mkChar(stream.str().c_str()));
383+
}
367384
}
368385
R_API_END();
369386
UNPROTECT(1);

R-package/src/xgboost_R.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,9 @@ XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle);
185185
* \param handle handle
186186
* \param fmap name to fmap can be empty string
187187
* \param with_stats whether dump statistics of splits
188+
* \param dump_format the format to dump the model in
188189
*/
189-
XGB_DLL SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats);
190+
XGB_DLL SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_format);
190191

191192
/*!
192193
* \brief get learner attribute value

R-package/tests/testthat/test_callbacks.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,11 @@ test_that("cb.reset.parameters works as expected", {
147147
bst4 <- xgb.train(param, dtrain, nrounds = 2, watchlist,
148148
callbacks = list(cb.reset.parameters(my_par)))
149149
, NA) # NA = no error
150+
# CV works as well
151+
expect_error(
152+
bst4 <- xgb.cv(param, dtrain, nfold = 2, nrounds = 2,
153+
callbacks = list(cb.reset.parameters(my_par)))
154+
, NA) # NA = no error
150155

151156
# expect no learning with 0 learning rate
152157
my_par <- list(eta = c(0., 0.))

R-package/tests/testthat/test_helpers.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ test_that("xgb.dump works", {
2727
expect_true(xgb.dump(bst.Tree, 'xgb.model.dump', with_stats = T))
2828
expect_true(file.exists('xgb.model.dump'))
2929
expect_gt(file.size('xgb.model.dump'), 8000)
30+
31+
# JSON format
32+
dmp <- xgb.dump(bst.Tree, dump_format = "json")
33+
expect_length(dmp, 1)
34+
expect_length(grep('nodeid', strsplit(dmp, '\n')[[1]]), 162)
3035
})
3136

3237
test_that("xgb.dump works for gblinear", {
@@ -38,6 +43,11 @@ test_that("xgb.dump works for gblinear", {
3843
d.sp <- xgb.dump(bst.GLM.sp)
3944
expect_length(d.sp, 14)
4045
expect_gt(sum(d.sp == "0"), 0)
46+
47+
# JSON format
48+
dmp <- xgb.dump(bst.GLM.sp, dump_format = "json")
49+
expect_length(dmp, 1)
50+
expect_length(grep('\\d', strsplit(dmp, '\n')[[1]]), 11)
4151
})
4252

4353
test_that("xgb-attribute functionality", {
@@ -83,6 +93,8 @@ test_that("xgb-attribute numeric precision", {
8393
for (x in X) {
8494
xgb.attr(bst.Tree, "x") <- x
8595
expect_identical(as.numeric(xgb.attr(bst.Tree, "x")), x)
96+
xgb.attributes(bst.Tree) <- list(a = "A", b = x)
97+
expect_identical(as.numeric(xgb.attr(bst.Tree, "b")), x)
8698
}
8799
})
88100

0 commit comments

Comments
 (0)