jgrss / geowombat

GeoWombat: Utilities for geospatial data
https://geowombat.readthedocs.io
MIT License

Regression capability in geowombat #284

Closed: ritviksahajpal closed this issue 9 months ago

ritviksahajpal commented 10 months ago

Thanks for a great package! Are you planning to include the capability to use geowombat for solving regression tasks?

Thanks! Ritvik

jgrss commented 10 months ago

Hi @ritviksahajpal, thanks for the question. Are you referring to the ml module? Currently, there are no plans to add regression. However, we (@mmann1123) would be happy to look into it if there is interest. Do you have a specific idea or example in mind, or are you thinking of a general framework similar to the classifiers module?

ritviksahajpal commented 10 months ago

Thank you! Yes, I'm referring to the ml module. A general framework is what I had in mind, and I would be happy to help test it once it's done.

mmann1123 commented 10 months ago

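This should already be possible with the following approach: extract training values from the image stack, fit a scikit-learn pipeline, and then apply the fitted pipeline to the stack block by block with `gw.apply`. The example also shows how to include preprocessing steps in the pipeline.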


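```python
import geowombat as gw
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# target_string, select_images and lu_poly are placeholders for your own
# reference image, predictor rasters and training polygons
with gw.config.update(ref_image=target_string):
    with gw.open(select_images, nodata=9999, stack_dim="band") as src:
        # Create a prediction stack where each variable used in the
        # regression is a band
        src.gw.save(
            "outputs/pred_stack.tif",
            compress="lzw",
            overwrite=True,  # bigtiff=True
        )
        # Extract the band values that fall within the training polygons
        df = gw.extract(src, lu_poly, verbose=1)
        # Select the response (y) and the predictor columns (X) from df.
        # Keep the X columns in the same order as the bands in the stack,
        # because user_func below feeds the bands to the model in that order.
        y = something
        X = something

# Create a pipeline with the preprocessor and the OLS model
pipeline_performance = Pipeline([
    ("scaler", StandardScaler()),
    ("model", LinearRegression()),
])

# Use scikit-learn to fit the desired model
pipeline_performance.fit(X, y)
```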
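To show what the two pipeline steps compute, and why the whole fitted pipeline (not only the regression model) is passed to `user_func`, here is a NumPy-only sketch of the equivalent computation. `ScaledOLS` is a hypothetical class written for illustration, not part of scikit-learn or geowombat. It learns a per-band mean and standard deviation at fit time, as `StandardScaler` does, and then fits ordinary least squares with an intercept, as `LinearRegression` does. Because the scaling statistics are stored on the fitted object, every new pixel is standardized with the training statistics at prediction time.

```python
import numpy as np


class ScaledOLS:
    """Hypothetical stand-in for Pipeline([StandardScaler(), LinearRegression()]).

    Illustrative only: it mirrors what the two steps compute, not how
    scikit-learn implements them internally.
    """

    def fit(self, X, y):
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float)
        # Scaler step: learn the per-band mean and (population) std once
        self.mean_ = X.mean(axis=0)
        self.scale_ = X.std(axis=0)
        self.scale_[self.scale_ == 0] = 1.0  # leave constant bands unscaled
        Z = (X - self.mean_) / self.scale_
        # OLS step: least squares on [1, Z] to get intercept + coefficients
        A = np.column_stack([np.ones(len(Z)), Z])
        coef, *_ = np.linalg.lstsq(A, y, rcond=None)
        self.intercept_, self.coef_ = coef[0], coef[1:]
        return self

    def predict(self, X):
        # Reuse the training mean/std so new pixels are scaled identically
        Z = (np.asarray(X, dtype=float) - self.mean_) / self.scale_
        return self.intercept_ + Z @ self.coef_


# Usage: noise-free synthetic data with 50 samples and 3 bands
rng = np.random.default_rng(0)
X_train = rng.normal(100.0, 20.0, size=(50, 3))
y_train = 2.0 + X_train @ np.array([0.5, -1.0, 0.25])
model = ScaledOLS().fit(X_train, y_train)
print(np.allclose(model.predict(X_train), y_train))  # True
```

Fitting the real `Pipeline` and passing it to `gw.apply` gives the same guarantee: the scaler fitted on the extracted training values is applied unchanged to every block.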

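```python
# Predict to the stack, one block at a time
def user_func(w, block, model):
    """Apply a fitted model to a (bands, rows, cols) block."""
    n_bands, n_rows, n_cols = block.shape
    # One row per pixel, one column per band
    X_block = block.reshape(n_bands, -1).T
    y_hat = model.predict(X_block)
    # Reshape the predictions back to a single-band (1, rows, cols) block
    return w, y_hat.reshape(1, n_rows, n_cols)
```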
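To check that the reshape in `user_func` puts each prediction back on the pixel it came from, here is a small NumPy-only test. `SumModel` is a hypothetical stand-in for the fitted pipeline (it just sums the bands) so the example runs without geowombat or scikit-learn, and the window `w` is passed as `None` because `user_func` only passes it through.

```python
import numpy as np


class SumModel:
    """Hypothetical stand-in for a fitted model: predicts the sum of the bands."""

    def predict(self, X):
        return X.sum(axis=1)


def user_func(w, block, model):
    n_bands, n_rows, n_cols = block.shape
    X_block = block.reshape(n_bands, -1).T  # (pixels, bands)
    y_hat = model.predict(X_block)
    return w, y_hat.reshape(1, n_rows, n_cols)


# A block with 2 bands and 3 x 4 pixels
block = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)
_, pred = user_func(None, block, SumModel())
print(pred.shape)  # (1, 3, 4)
print(np.array_equal(pred[0], block[0] + block[1]))  # True
```

Both reshapes use NumPy's default row-major order, so pixel `(r, c)` becomes sample `r * n_cols + c` and is mapped back to the same position.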

```python
# Create the prediction image
gw.apply(
    "outputs/pred_stack.tif",
    "outputs/final_model_OLS_prediction.tif",
    user_func,
    args=(pipeline_performance,),
    n_jobs=16,
    count=1,
)
```
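The prediction function above sends every pixel to `model.predict`, including pixels where the stack holds the no-data value (9999, as passed to `gw.open`). Below is a sketch of a masked variant, assuming no-data pixels are stored as 9999 in `outputs/pred_stack.tif`. `user_func_masked` and `MeanModel` are hypothetical names, and `MeanModel` stands in for the fitted pipeline. The function keeps the same `(w, block, model)` signature, so it could be passed to `gw.apply` in place of `user_func` with the same `args`.

```python
import numpy as np

NODATA = 9999  # assumed to match the nodata value passed to gw.open


class MeanModel:
    """Hypothetical stand-in for the fitted pipeline: predicts the band mean."""

    def predict(self, X):
        return X.mean(axis=1)


def user_func_masked(w, block, model, nodata=NODATA):
    """Predict only where every band is valid; write `nodata` elsewhere."""
    n_bands, n_rows, n_cols = block.shape
    X_block = block.reshape(n_bands, -1).T.astype("float64")
    valid = np.all(X_block != nodata, axis=1)  # pixels with no missing band
    y_hat = np.full(X_block.shape[0], nodata, dtype="float64")
    if valid.any():  # skip predict() entirely for an all-nodata block
        y_hat[valid] = model.predict(X_block[valid])
    return w, y_hat.reshape(1, n_rows, n_cols)


# Usage: 2 bands, 2 x 2 pixels; two pixels have one missing band each
block = np.array([[[1.0, 9999.0], [3.0, 4.0]],
                  [[5.0, 6.0], [9999.0, 8.0]]])
_, pred = user_func_masked(None, block, MeanModel())
print(pred[0].tolist())  # [[3.0, 9999.0], [9999.0, 6.0]]
```

Masking before calling `predict` also avoids passing the large 9999 values through the scaler, where they would otherwise produce extreme, meaningless predictions.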

That being said, I will eventually look into adding regression more directly as an option.