diff --git a/esrun_data/mangrove/dataset.json b/esrun_data/mangrove/dataset.json new file mode 100644 index 00000000..a6c9cec5 --- /dev/null +++ b/esrun_data/mangrove/dataset.json @@ -0,0 +1,98 @@ +{ + "layers": { + "label": { + "type": "vector" + }, + "sentinel1": { + "band_sets": [ + { + "bands": [ + "vv", + "vh" + ], + "dtype": "float32" + } + ], + "data_source": { + "time_offset": "-180d", + "duration": "366d", + "cache_dir": "cache/planetary_computer", + "ingest": false, + "name": "rslp.satlas.data_sources.MonthlySentinel1", + "query": { + "sar:instrument_mode": { + "eq": "IW" + }, + "sar:polarizations": { + "eq": [ + "VV", + "VH" + ] + } + }, + "query_config": { + "max_matches": 12 + } + }, + "type": "raster" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "data_source": { + "time_offset": "-180d", + "duration": "366d", + "cache_dir": "cache/planetary_computer", + "harmonize": true, + "ingest": false, + "max_cloud_cover": 50, + "name": "rslp.satlas.data_sources.MonthlyAzureSentinel2", + "query_config": { + "max_matches": 12 + }, + "sort_by": "eo:cloud_cover" + }, + "type": "raster" + }, + "output": { + "band_sets": [ + { + "bands": [ + "output" + ], + "dtype": "uint8" + } + ], + "type": "raster" + } + } + } \ No newline at end of file diff --git a/esrun_data/mangrove/esrun.yaml b/esrun_data/mangrove/esrun.yaml new file mode 100644 index 00000000..d744280a --- /dev/null +++ b/esrun_data/mangrove/esrun.yaml @@ -0,0 +1,38 @@ +partition_strategies: + partition_request_geometry: + class_path: esrun.runner.tools.partitioners.grid_partitioner.GridPartitioner + init_args: + grid_size: .1 # Lat lon degrees, a random guess not sure what is really reasonable + output_projection: + class_path: rslearn.utils.geometry.Projection + init_args: + crs: EPSG:3857 + x_resolution: 10 + y_resolution: -10 + use_utm: true + + prepare_window_geometries: + class_path: esrun.runner.tools.partitioners.grid_partitioner.GridPartitioner + init_args: + grid_size: 32 # 32 based on the window creation script yawen made + output_projection: + class_path: rslearn.utils.geometry.Projection + init_args: + crs: EPSG:3857 + x_resolution: 10 + y_resolution: -10 + use_utm: true + +postprocessing_strategies: + process_dataset: + class_path: esrun.runner.tools.postprocessors.combine_geotiff.CombineGeotiff + + process_partition: + class_path: esrun.runner.tools.postprocessors.combine_geotiff.CombineGeotiff + + process_window: + class_path: esrun.runner.tools.postprocessors.noop_raster.NoopRaster + + +inference_results_config: + data_type: RASTER diff --git a/esrun_data/mangrove/model.yaml b/esrun_data/mangrove/model.yaml new file mode 100644 index 00000000..eff0e079 --- /dev/null +++ b/esrun_data/mangrove/model.yaml @@ -0,0 +1,457 @@ +# lightning.pytorch==2.5.1.post0 +seed_everything: 0 +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: ${EXTRA_FILES_PATH}/step300000 + selector: + - encoder + forward_kwargs: + patch_size: 2 + random_initialization: true + autocast_dtype: bfloat16 + decoders: + mangrove_classification: + - class_path: rslp.crop.kenya_nandi.train.SegmentationPoolingDecoder + init_args: + in_channels: 768 + out_channels: 4 + num_conv_layers: 0 + num_fc_layers: 0 + conv_channels: 128 + fc_channels: 512 + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + lazy_decode: false + loss_weights: null + optimizer: null + scheduler: null + visualize_dir: null + metrics_file: null + restore_config: null + print_parameters: false + print_model: false + lr: 0.0001 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0.0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + inputs: + label: + class_path: rslearn.train.dataset.DataInput + init_args: + data_type: raster + layers: + - label_raster + bands: + - label + required: true + passthrough: false + is_target: true + dtype: INT32 + load_all_layers: false + load_all_item_groups: false + sentinel2_l2a: + class_path: rslearn.train.dataset.DataInput + init_args: + data_type: raster + layers: + - sentinel2 + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + required: true + passthrough: true + is_target: false + dtype: FLOAT32 + load_all_layers: true + load_all_item_groups: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + mangrove_classification: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + num_classes: 4 + colors: # what is this? + - - 255 + - 0 + - 0 + - - 0 + - 255 + - 0 + - - 0 + - 0 + - 255 + - - 255 + - 255 + - 0 + - - 0 + - 255 + - 255 + - - 255 + - 0 + - 255 + - - 0 + - 128 + - 0 + - - 255 + - 160 + - 122 + - - 139 + - 69 + - 19 + - - 128 + - 128 + - 128 + - - 255 + - 255 + - 255 + - - 143 + - 188 + - 143 + - - 95 + - 158 + - 160 + - - 255 + - 200 + - 0 + - - 128 + - 0 + - 0 + zero_is_invalid: true + enable_accuracy_metric: true + enable_miou_metric: false + enable_f1_metric: false + f1_metric_thresholds: + - - 0.5 + metric_kwargs: + average: micro + miou_metric_kwargs: {} + prob_scales: null + other_metrics: + mangrove_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 4 + top_k: 1 + average: null + multidim_average: global + ignore_index: null + validate_args: true + zero_division: 0 + pass_probabilities: true + class_idx: 1 + mangrove_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 4 + top_k: 1 + average: null + multidim_average: global + ignore_index: null + validate_args: true + zero_division: 0 + pass_probabilities: true + class_idx: 1 + other_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 4 + top_k: 1 + average: null + multidim_average: global + ignore_index: null + validate_args: true + zero_division: 0 + pass_probabilities: true + class_idx: 3 + other_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 4 + top_k: 1 + average: null + multidim_average: global + ignore_index: null + validate_args: true + zero_division: 0 + pass_probabilities: true + class_idx: 3 + water_precision: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassPrecision + init_args: + num_classes: 4 + top_k: 1 + average: null + multidim_average: global + ignore_index: null + validate_args: true + zero_division: 0 + pass_probabilities: true + class_idx: 2 + water_recall: + class_path: rslearn.train.tasks.segmentation.SegmentationMetric + init_args: + metric: + class_path: torchmetrics.classification.MulticlassRecall + init_args: + num_classes: 4 + top_k: 1 + average: null + multidim_average: global + ignore_index: null + validate_args: true + zero_division: 0 + pass_probabilities: true + class_idx: 2 + image_bands: + - 0 + - 1 + - 2 + remap_values: null + input_mapping: + mangrove_classification: + label: targets + path: ${DATASET_PATH} + path_options: {} + batch_size: 128 # Modified for testing + num_workers: ${NUM_WORKERS} + init_workers: 0 + default_config: + class_path: rslearn.train.dataset.SplitConfig + init_args: + groups: null + names: null + tags: null + num_samples: null + num_patches: null + transforms: + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 2 + mode: center + image_selectors: + - sentinel2_l2a + - target/mangrove_classification/classes + - target/mangrove_classification/valid + box_selectors: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + sentinel2_l2a: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + std_multiplier: 2.0 + sampler: null + patch_size: null + overlap_ratio: null + load_all_patches: null + skip_targets: null + train_config: + class_path: rslearn.train.dataset.SplitConfig + init_args: + groups: + - sample_100K + names: null + tags: + split: train + num_samples: 10000 + num_patches: null + transforms: null + sampler: null + patch_size: null + overlap_ratio: null + load_all_patches: null + skip_targets: null + val_config: + class_path: rslearn.train.dataset.SplitConfig + init_args: + groups: + - sample_100K + names: null + tags: + split: val + num_samples: null + num_patches: null + transforms: null + sampler: null + patch_size: null + overlap_ratio: null + load_all_patches: null + skip_targets: null + test_config: + class_path: rslearn.train.dataset.SplitConfig + init_args: + groups: + - sample_100K + names: null + tags: + split: val + num_samples: null + num_patches: null + transforms: null + sampler: null + patch_size: null + overlap_ratio: null + load_all_patches: null + skip_targets: null + predict_config: + patch_size: 2 #I think this is actually the window size for inference processing + overlap_ratio: null + load_all_patches: true + skip_targets: true + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + sentinel2_l2a: [] + output_selector: sentinel2_l2a + # I feel like we really don;t want this + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + sentinel2_l2a: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 2 + mode: "center" + image_selectors: ["sentinel2_l2a"] + name: null + retries: 0 +trainer: + profiler: simple + accelerator: auto + strategy: + class_path: lightning.pytorch.strategies.DDPStrategy + init_args: + find_unused_parameters: true + devices: auto + num_nodes: 1 + callbacks: + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: ${DATASET_PATH} + output_layer: ${PREDICTION_OUTPUT_LAYER} + selector: ["mangrove_classification"] + merger: null + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + dirpath: gs://rslearn-eai/projects/2025_09_15_mangrove_classification/mangrove_classification_segment_helios_base_S2_ts_ws2_ps2/checkpoints + filename: null + monitor: val_loss + verbose: false + save_last: true + save_top_k: 1 + save_weights_only: false + mode: min + auto_insert_metric_name: true + every_n_train_steps: null + train_time_interval: null + every_n_epochs: null + save_on_train_epoch_end: null + enable_version_counter: true +# lighning wants these other + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: + - model + - encoder + - 0 + unfreeze_at_epoch: 20 + unfreeze_lr_factor: 10.0 + fast_dev_run: false + max_epochs: 100 + min_epochs: null + max_steps: -1 + min_steps: null + max_time: null + limit_train_batches: null + limit_val_batches: null + limit_test_batches: null + limit_predict_batches: null + overfit_batches: 0.0 + val_check_interval: null + check_val_every_n_epoch: 1 + num_sanity_val_steps: null + log_every_n_steps: null + enable_checkpointing: null + enable_progress_bar: null + enable_model_summary: null + accumulate_grad_batches: 1 + gradient_clip_val: null + gradient_clip_algorithm: null + deterministic: null + benchmark: null + inference_mode: true + use_distributed_sampler: false + profiler: null + detect_anomaly: false + barebones: false + plugins: null + sync_batchnorm: false + reload_dataloaders_every_n_epochs: 0 + default_root_dir: null + model_registry: null +# rslp_project: ${WANDB_PROJECT} +# rslp_experiment: ${WANDB_NAME} +# ${EXTRA_FILES_PATH}: Path to the extra pretrained model/data preprocessing config files +# ${TRAINER_DATA_PATH}: Path to the rslearn trainer data directory. For intermediate checkpoints, trainer state, etc. +# ${WANDB_ENTITY}: wandb entity for the trainer to log metrics to \ No newline at end of file diff --git a/esrun_data/mangrove/prediction_request_geometry.geojson b/esrun_data/mangrove/prediction_request_geometry.geojson new file mode 100644 index 00000000..cf94439d --- /dev/null +++ b/esrun_data/mangrove/prediction_request_geometry.geojson @@ -0,0 +1,53 @@ +{ + "type": "FeatureCollection", + "name": "region_of_interests_for_testing", + "crs": { + "type": "name", + "properties": { + "name": "urn:ogc:def:crs:OGC:1.3:CRS84" + } + }, + "features": [ + { + "type": "Feature", + "properties": { + "fid": 136, + "MinX": 55.0, + "MaxX": 56.0, + "MinY": 26.0, + "MaxY": 27.0, + "tile_name": "N27E055", + "gmw_tile_name": "GMW_N27E055", + "es_start_time": "2018-01-01T00:00:00Z", + "es_end_time": "2018-03-01T23:59:59Z" + }, + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [ + 55.0, + 27.0 + ], + [ + 55.05, + 27.0 + ], + [ + 55.05, + 26.95 + ], + [ + 55.0, + 26.95 + ], + [ + 55.0, + 27.0 + ] + ] + ] + } + } + ] + } \ No newline at end of file diff --git a/esrun_data/sample/README.md b/esrun_data/sample/README.md index cc98d6a3..5afead5a 100644 --- a/esrun_data/sample/README.md +++ b/esrun_data/sample/README.md @@ -133,6 +133,9 @@ There are 3 different stages to postprocessing: ### Samples + +TODO: Add examples with the new rslp cli for this + #### Run a pipeline end-to-end The simplest way to run a pipeline is to use the `esrun-local-predict` CLI command. This command will run the entire pipeline end-to-end including partitioning, dataset building, inference, post-processing, and combining the final outputs. diff --git a/rslp/esrun/esrun.py b/rslp/esrun/esrun.py index e5b17766..e69e8f58 100644 --- a/rslp/esrun/esrun.py +++ b/rslp/esrun/esrun.py @@ -1,6 +1,7 @@ """Run EsPredictRunner inference pipeline.""" import hashlib +import logging import shutil import tempfile from enum import StrEnum @@ -9,6 +10,8 @@ import fsspec from esrun.runner.local.fine_tune_runner import EsFineTuneRunner from esrun.runner.local.predict_runner import EsPredictRunner +from esrun.shared.models.task_results import InferenceResultsDataType +from esrun.shared.tools.logger import configure_logging from upath import UPath from rslp.log_utils import get_logger @@ -69,12 +72,16 @@ def esrun(config_path: Path, scratch_path: Path, checkpoint_path: str) -> None: scratch_path: directory to use for scratch space. checkpoint_path: path to the model checkpoint. """ + # Configure esrun logging before creating the runner + configure_logging(log_level=logging.INFO) + runner = EsPredictRunner( # ESRun does not work with relative path, so make sure to convert to absolute here. project_path=config_path.absolute(), scratch_path=scratch_path, checkpoint_path=get_local_checkpoint(UPath(checkpoint_path)), ) + logger.info("Partitioning...") partitions = runner.partition() logger.info(f"Got {len(partitions)} partitions") @@ -86,6 +93,7 @@ def esrun(config_path: Path, scratch_path: Path, checkpoint_path: str) -> None: runner.run_inference(partition_id) logger.info(f"Postprocessing for partition {partition_id}") runner.postprocess(partition_id) + break logger.info("Combining across partitions") runner.combine(partitions) @@ -109,6 +117,7 @@ def one_stage( checkpoint_path: str, stage: EsrunStage, partition_id: str | None = None, + inference_results_data_type: InferenceResultsDataType = InferenceResultsDataType.RASTER, ) -> None: """Run EsPredictRunner inference pipeline. @@ -124,6 +133,9 @@ def one_stage( if stage == EsrunStage.COMBINE and partition_id is not None: raise ValueError("partition_id cannot be set for COMBINE stage") + # Configure esrun logging before creating the runner + configure_logging(log_level=logging.INFO) + runner = EsPredictRunner( # ESRun does not work with relative path, so make sure to convert to absolute here. project_path=config_path, @@ -143,6 +155,7 @@ def one_stage( if stage == EsrunStage.RUN_INFERENCE: fn = runner.run_inference elif stage == EsrunStage.POSTPROCESS: + runner.inference_results_data_type = inference_results_data_type fn = runner.postprocess else: assert False diff --git a/rslp/mangrove/create_windows_for_classification.py b/rslp/mangrove/create_windows_for_classification.py index 0fe6d49b..df5c929c 100644 --- a/rslp/mangrove/create_windows_for_classification.py +++ b/rslp/mangrove/create_windows_for_classification.py @@ -25,6 +25,7 @@ START_TIME = datetime(2020, 6, 15, tzinfo=timezone.utc) END_TIME = datetime(2020, 7, 15, tzinfo=timezone.utc) +# For every latlon we get a window of size 32 around the data def create_window( csv_row: pd.Series, ds_path: UPath, window_size: int, group_name: str