google / Xee

An Xarray extension for Google Earth Engine
https://xee.rtfd.io
Apache License 2.0

Estimating EECU hours #104

ljstrnadiii opened 1 year ago

ljstrnadiii commented 1 year ago

This is definitely a nice-to-have, but I am wondering if there is a reliable way to estimate the number of EECU-hours a workload will use. Since we pay for these hours, it would be nice to avoid the "oops, just spent $1k+ on an experimental dataset" issue.

We currently split the data into fixed-size tiles, take a sample of those tiles, run an export task to Cloud Storage for each one, and get summary stats of the EECU-hours with ee.data.getOperation(f"projects/earthengine-legacy/operations/{task_id}"). This lets us roughly estimate the cost of "ingest".
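The tile-sampling approach can be turned into an estimate with a few lines. This is a hypothetical helper, and it assumes the operation dict returned by `ee.data.getOperation(...)` for a finished export reports its usage as `metadata["batchEecuUsageSeconds"]`. Check that field name against the operations you actually get back.

```python
from statistics import mean


def extrapolate_export_cost(
    sample_operations: list[dict],
    total_tiles: int,
    usd_per_eecu_hour: float,
) -> dict[str, float]:
    """Extrapolate the cost of exporting every tile from a sample of export tasks.

    Each item is assumed to look like the dict returned by ee.data.getOperation(...)
    for a finished export, with EECU usage under metadata["batchEecuUsageSeconds"]
    (an assumption; a missing field raises KeyError instead of counting as zero).
    """
    if not sample_operations:
        raise ValueError("need at least one sampled operation")
    # EECU-seconds per sampled task, converted to EECU-hours.
    hours = [
        float(op["metadata"]["batchEecuUsageSeconds"]) / 3600
        for op in sample_operations
    ]
    hours_per_tile = mean(hours)
    total_hours = hours_per_tile * total_tiles
    return {
        "eecu_hours_per_tile": hours_per_tile,
        "total_eecu_hours": total_hours,
        "total_usd": total_hours * usd_per_eecu_hour,
    }
```

The inputs would be the `ee.data.getOperation(f"projects/earthengine-legacy/operations/{task_id}")` results for the sampled tasks. The mean is only as good as the sample: tiles over very different terrain or data density can cost very different amounts.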

I think this is hard in the general case, but maybe we could build a recipe that samples/slices along time/x/y/other dims to build an estimate of EECU cost? Ideally this would be a function on image collections themselves, but I am guessing EECU time varies between ee.data.getPixels, export to Cloud Storage, and other access paths. Thoughts?

ljstrnadiii commented 1 year ago

It looks like one possibility is to annotate Earth Engine operations with Workload Tags and then query Metrics Explorer for earthengine.googleapis.com/project/cpu/usage_time. This might assume a bit much and is likely beyond the scope of this package. @alxmrs feel free to close if you think it is out of scope.
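With Workload Tags, the Cloud Monitoring query can select the tag server-side instead of filtering every time series in Python. A minimal sketch of building that filter string, assuming the metric exposes the tag as a `workload_tag` metric label (the label name the code later in this thread reads); `workload_tag_filter` is a hypothetical helper.

```python
def workload_tag_filter(workload_tag: str) -> str:
    """Build a Cloud Monitoring filter for Earth Engine EECU usage of one tag."""
    metric_type = "earthengine.googleapis.com/project/cpu/usage_time"
    resource_type = "earthengine.googleapis.com/Project"
    # The label clause narrows the result to series for this workload tag only.
    return (
        f'metric.type="{metric_type}" '
        f'AND resource.type="{resource_type}" '
        f'AND metric.labels.workload_tag="{workload_tag}"'
    )
```

The result would be passed as `filter=` to `MetricServiceClient.list_time_series`.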

ljstrnadiii commented 1 year ago

To sketch things out, I am estimating with something like:

import datetime
import time
from dataclasses import dataclass
from math import ceil, floor
from random import shuffle

import xarray as xr
from dataclasses_json import dataclass_json
from google.cloud import monitoring_v3
from loguru import logger
import numpy as np
from itertools import product
import ee

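EE_USD_PER_EECU_HOUR = ...  # price per EECU-hour for your billing plan
PROJECT = ...  # Cloud Monitoring expects "projects/<PROJECT_ID_OR_NUMBER>"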

@dataclass_json
@dataclass
class CostEstimate:
    cost_usd: float
    approximate_cost_per_chunk: float
    eecu_hours_per_chunk: float
    number_chunks: int
    workload_tag: str

def _get_dataset_chunk_slices(
    dset: xr.Dataset | xr.DataArray,
) -> list[dict[str, slice]]:
    """Return one {dim: slice} dict per dask chunk of `dset` (must be chunked)."""
    dims = [str(d) for d in dset.dims]
    start_steps = {}
    for dim in dims:
        # `chunksizes` is keyed by dim name for both Dataset and DataArray;
        # `DataArray.chunks` is a plain tuple, so indexing it by name fails.
        sizes = list(dset.chunksizes[dim])
        dim_starts = np.cumsum([0] + sizes[:-1])
        start_steps[dim] = list(zip(dim_starts, sizes))
    # Cartesian product of the per-dim (start, size) pairs -> every chunk.
    chunk_slices = product(*(start_steps[dim] for dim in dims))
    all_slices = []
    for dim_sstep in chunk_slices:
        slices = {}
        for d, (d_start, d_step) in zip(dims, dim_sstep):
            slices[d] = slice(int(d_start), int(d_start + d_step))
        all_slices.append(slices)
    return all_slices

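def estimate_xee_dataset_cost(
    dset: xr.Dataset,
    workload_tag: str,
    sample_size: int = 3,
    stable_iterations: int = 10,
) -> CostEstimate:
    """Compute a random sample of chunks and extrapolate their EECU cost."""
    start_time = datetime.datetime.now()
    slices = _get_dataset_chunk_slices(dset)
    shuffle(slices)
    sample = slices[:sample_size]  # fewer than sample_size if there are few chunks
    # Reducing each chunk keeps the computed result small; the chunk's pixels
    # are still fetched from Earth Engine, which is what incurs EECU usage.
    means = [dset.isel(slc).mean() for slc in sample]

    logger.info(f"computing {len(sample)} chunks to estimate cost...")
    _ = xr.concat(means, dim="slice").compute()

    logger.info("estimating cost by polling...")
    eecu_hours = _estimate_eecu_hour_from_workload_tag(
        start_time, workload_tag, stable_iterations=stable_iterations
    )
    eecu_hours_per_chunk = eecu_hours / len(sample)
    cost_per_chunk = eecu_hours_per_chunk * EE_USD_PER_EECU_HOUR
    return CostEstimate(
        cost_usd=cost_per_chunk * len(slices),
        approximate_cost_per_chunk=cost_per_chunk,
        eecu_hours_per_chunk=eecu_hours_per_chunk,
        number_chunks=len(slices),
        workload_tag=workload_tag,
    )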

def _estimate_eecu_hour_from_workload_tag(
    start_time: datetime.datetime,
    workload_tag: str,
    poll_interval: int = 10,
    stable_iterations: int = 10,  # usually takes 7 iterations to stabilize.
) -> float:
    """Poll Cloud Monitoring and return the total EECU-hours for `workload_tag`.

    The metric lags actual usage, so this keeps polling until it has seen
    `stable_iterations` successful, non-decreasing readings.
    Note: it loops indefinitely if the tag never appears in the metric.
    """
    running_value = -1.0
    bumps = 0
    while bumps < stable_iterations:
        try:
            current_value = _poll_eecu_seconds(
                start_time, datetime.datetime.now(), workload_tag
            )
            if current_value >= running_value:
                running_value = current_value
                bumps += 1
        except ValueError as e:
            # Raised until the first data point for the tag is reported.
            logger.info(f"exception found (likely because there is no metric yet): {e}")
        logger.info(f"current estimate: {running_value} for workload: {workload_tag}")
        time.sleep(poll_interval)
    return running_value

def _poll_eecu_seconds(
    start_time: datetime.datetime, end_time: datetime.datetime, workload_tag: str
) -> float:
    """Return EECU usage for `workload_tag` between start_time and end_time.

    Despite the name, the metric's seconds are converted to hours.
    """
    metric_type = "earthengine.googleapis.com/project/cpu/usage_time"
    resource_type = "earthengine.googleapis.com/Project"
    filter_str = f'metric.type="{metric_type}" AND resource.type="{resource_type}"'

    # Apply floor/ceil to the float timestamps; wrapping them in int() first
    # truncates, which made the ceil() on end_time a no-op.
    interval = monitoring_v3.TimeInterval(
        start_time={"seconds": floor(start_time.timestamp())},
        end_time={"seconds": ceil(end_time.timestamp())},
    )

    client = monitoring_v3.MetricServiceClient()
    results = client.list_time_series(
        name=PROJECT,
        filter=filter_str,
        interval=interval,
        view=monitoring_v3.ListTimeSeriesRequest.TimeSeriesView.FULL,
    )
    # .get() avoids a KeyError on series that carry no workload_tag label.
    relevant_metrics = [
        t for t in results if t.metric.labels.get("workload_tag") == workload_tag
    ]
    if not relevant_metrics:
        raise ValueError(f"workload_tag {workload_tag} not found")

    # One tag may span several series (one per label combination); sum them all.
    total_seconds = sum(
        p.value.double_value for t in relevant_metrics for p in t.points
    )
    return total_seconds / 3600

Then we build a dataset with a workload tag:

workload_tag = "uuid-random-str"
with ee.data.workloadTagContext(workload_tag):
    # The dataset must be chunked (dask-backed) so its chunks can be sampled.
    dset = xr.open_dataset(..., engine='ee')
    cost_estimate = estimate_xee_dataset_cost(dset, workload_tag)

This is a sketch and the polling mechanism is a WIP, but the main thing to note is that EECU usage takes a while to show up in the metric: about 70 seconds on average for me. The reported value also tends to bump up twice while we poll.
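Because the metric lags and bumps up more than once, one option for the WIP polling is to stop only after the reading has stopped changing for several consecutive polls, rather than after a fixed number of successful polls. A minimal sketch of that check; `has_stabilized` and its parameters are hypothetical.

```python
def has_stabilized(
    readings: list[float], consecutive: int = 3, tol: float = 1e-9
) -> bool:
    """Return True once the last `consecutive` polls each equal their predecessor.

    `readings` is the history of cumulative EECU-hour values, oldest first.
    Values within `tol` of each other count as unchanged.
    """
    # Need consecutive + 1 readings to observe `consecutive` unchanged steps.
    if len(readings) < consecutive + 1:
        return False
    tail = readings[-(consecutive + 1):]
    return all(abs(b - a) <= tol for a, b in zip(tail, tail[1:]))
```

In the polling loop, each successful reading would be appended to a history list and the loop would break once `has_stabilized(history)` is true. A late bump after a quiet stretch can still be missed, so `consecutive * poll_interval` should comfortably exceed the lag observed (roughly 70 seconds here).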

jdbcode commented 1 year ago

@ljstrnadiii I've argued for Earth Engine cost-estimating tools for a long time. The engineers are reluctant given the variability and uncertainty: it is essentially the halting problem, unsolvable in principle in the general case. Some customer engineers, however, have worked with clients on fairly specific, constrained real-world cases to extrapolate in a way similar to what you've shown. I could imagine Earth Engine supporting similar sampling tools in the future.

Cost controls may also come in the future: systems to avoid the "oops I just spent 1k+$ on an experimental dataset" problem, possibly something like a cost or compute limit that triggers request/task termination.
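Until something like that exists server-side, a client-side stopgap is possible for batch tasks: periodically list operations and cancel any that exceed a per-task EECU budget. This is a sketch of the selection logic only. `operations_over_budget` is hypothetical, and it assumes Earth Engine operation dicts carry `name`, `done`, and `metadata["batchEecuUsageSeconds"]` (an assumed field name; verify against your own operations).

```python
def operations_over_budget(
    operations: list[dict], max_eecu_seconds: float
) -> list[str]:
    """Return names of unfinished operations whose EECU usage exceeds a limit.

    Operations without a reported usage value yet are treated as 0 seconds.
    """
    over = []
    for op in operations:
        if op.get("done"):
            continue  # finished operations can no longer be stopped
        used = float(op.get("metadata", {}).get("batchEecuUsageSeconds", 0.0))
        if used > max_eecu_seconds:
            over.append(op["name"])
    return over
```

The returned names could then be passed to the Earth Engine client's operation-cancel call (e.g. `ee.data.cancelOperation`; check the API reference for your client version). This only caps individual batch tasks after the fact and does not cover interactive requests such as those Xee makes.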

I'll share your post with the people looking into these topics so they are aware of the Xee case. I'm also flagging your code (or something similar) as a candidate for the Xee examples/docs, if you're okay with that.

ljstrnadiii commented 1 year ago

@jdbcode thanks a ton 🙏

I figured that if it were an easy problem, there would already be a solution. Using workload tags should be fine for now and will at least tell us when our sampling approach is way off.

Happy to add it to the docs, but to be completely honest, I am not very confident I am using the Monitoring API correctly; I don't know it well enough to understand the implications of alignment and sampling frequency. It also feels like more of an Earth Engine-specific thing. Happy to add it if people show interest and have useful feedback on the approach.