in Machine Learning by (19k points)

The problem is that my training data does not fit into RAM because of its size. So I need a method that first builds one tree on the whole training set, calculates the residuals, builds another tree, and so on (as gradient boosted trees do). Obviously, calling model = xgb.train(param, batch_dtrain, 2) in a loop will not help, because in that case it simply rebuilds the whole model for each batch.

1 Answer

by (33.1k points)

To implement incremental training with XGBoost:

Save your model after training on each batch. Then, on each subsequent run, pass the file path of the saved model to xgb.train through its xgb_model parameter, so that boosting continues from the existing trees instead of starting over.

To see the effect, split the training set into halves and fit a model on the first half. Then fit two models on the second half: one trained from scratch, and one given the additional parameter xgb_model so that it continues from the first model.

For example:

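```python
import xgboost as xgb
# sklearn.cross_validation was removed; train_test_split now lives in model_selection
from sklearn.model_selection import train_test_split as ttsplit
# load_boston was removed in scikit-learn 1.2; load_diabetes is a bundled regression dataset
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_squared_error as mse

X, y = load_diabetes(return_X_y=True)

# split data into training and testing sets,
# then split the training set in half
X_train, X_test, y_train, y_test = ttsplit(X, y, test_size=0.1, random_state=0)
X_train_1, X_train_2, y_train_1, y_train_2 = ttsplit(X_train, y_train, test_size=0.5, random_state=0)

xg_train_1 = xgb.DMatrix(X_train_1, label=y_train_1)
xg_train_2 = xgb.DMatrix(X_train_2, label=y_train_2)
xg_test = xgb.DMatrix(X_test, label=y_test)

# 'reg:linear' is deprecated in favour of 'reg:squarederror';
# the quiet-mode parameter is 'verbosity', not 'verbose'
params = {'objective': 'reg:squarederror', 'verbosity': 0}

# first model: 30 boosting rounds on the first half, saved to disk
model_1 = xgb.train(params, xg_train_1, 30)
model_1.save_model('model_1.json')

# second half, version 1: trained from scratch
model_2_v1 = xgb.train(params, xg_train_2, 30)
# second half, version 2: continues from model_1, adding 30 more trees
model_2_v2 = xgb.train(params, xg_train_2, 30, xgb_model='model_1.json')

print(mse(y_test, model_1.predict(xg_test)))
print(mse(y_test, model_2_v1.predict(xg_test)))
print(mse(y_test, model_2_v2.predict(xg_test)))

# Output reported in the original answer (Boston housing data, older library versions);
# values will differ with this dataset:
# 23.0475232194
# 39.6776876084
# 27.2053239482
```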

Hope this answer helps.
