Aggregation

There are three types of aggregate functions: tile aggregate, DataFrame aggregate, and element-wise local aggregate. Tile aggregate functions compute a statistical summary per row of a tile column in a DataFrame. DataFrame aggregate functions compute statistical summaries over all of the cell values across all of the rows in the DataFrame or group. Element-wise local aggregate functions compute a statistical summary at each cell position across a DataFrame or group of tiles.

Tile Mean Example

We can illustrate the differences among the aggregate functions by computing an aggregate mean. First, we create a sample DataFrame of two tiles. The tiles contain normally distributed cell values, with the first tile's mean at 1.0 and the second tile's mean at 5.0. For details on use of the Tile class see the page on NumPy interoperability.

import numpy as np
from pyrasterframes.rf_types import Tile, CellType

t1 = Tile(1 + 0.1 * np.random.randn(5,5), CellType('float64raw'))

t1.cells  # display the array in the Tile
array([[1.072, 0.867, 0.864, 1.039, 0.978],
       [1.068, 0.84 , 1.023, 0.953, 0.99 ],
       [0.828, 1.149, 0.864, 0.854, 0.96 ],
       [1.036, 0.987, 1.087, 1.164, 0.921],
       [1.078, 0.846, 0.967, 1.003, 1.109]])
t5 = Tile(5 + 0.1 * np.random.randn(5,5), CellType('float64raw'))
t5.cells
array([[5.099, 4.895, 4.853, 5.137, 4.981],
       [4.957, 4.907, 5.128, 5.022, 4.888],
       [4.976, 4.881, 5.009, 5.039, 5.055],
       [4.968, 4.963, 4.933, 4.983, 4.886],
       [4.843, 5.042, 5.057, 4.972, 5.002]])

Create a Spark DataFrame from the Tile objects.

import pyspark.sql.functions as F
from pyspark.sql import Row
from pyrasterframes.rasterfunctions import *

rf = spark.createDataFrame([
    Row(id=1, tile=t1),
    Row(id=2, tile=t5)
]).orderBy('id')

We use the rf_tile_mean function to compute the tile aggregate mean of cells in each row of the tile column. The mean of each tile is computed separately, so the first mean is about 1.0 and the second mean is about 5.0. Notice that the number of rows in the DataFrame is the same before and after the aggregation.

rf.select(F.col('id'), rf_tile_mean(F.col('tile')))
id rf_tile_mean(tile)
1 0.9818299200833603
2 4.979021985349441
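As a quick sanity check outside Spark, the same per-tile means can be computed directly with NumPy. This is only a sketch: the arrays below are seeded for reproducibility, unlike the unseeded tiles above, so the exact cell values differ.

```python
import numpy as np

np.random.seed(0)  # seeded for reproducibility; the tiles above are unseeded
a1 = 1 + 0.1 * np.random.randn(5, 5)
a5 = 5 + 0.1 * np.random.randn(5, 5)

# Each tile aggregate mean stays near its distribution mean
print(round(a1.mean(), 1))  # close to 1.0
print(round(a5.mean(), 1))  # close to 5.0
```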

We use the rf_agg_mean function to compute the DataFrame aggregate, which averages values across all fifty cells in the two rows. Note that only a single row is returned, since the average is computed over the full DataFrame.

rf.agg(rf_agg_mean(F.col('tile')))
rf_agg_mean(tile)
2.9804259527164003
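The pooled behavior can be sketched in NumPy: rf_agg_mean is equivalent to flattening all cell values from every tile into one collection and taking a single mean. Seeded arrays stand in for the tiles above.

```python
import numpy as np

np.random.seed(0)  # seeded for reproducibility; the tiles above are unseeded
a1 = 1 + 0.1 * np.random.randn(5, 5)
a5 = 5 + 0.1 * np.random.randn(5, 5)

# Pool all 50 cells from both tiles into one average,
# landing midway between the two tile means
overall = np.concatenate([a1, a5]).mean()
print(round(overall, 2))  # close to 3.0
```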

We use the rf_agg_local_mean function to compute the element-wise local aggregate mean across the two rows. For this aggregation, we are computing the mean of one value near 1.0 and one value near 5.0 to arrive at an element-wise mean near 3.0, doing so twenty-five times, once for each position in the tile.

To compute an element-wise local aggregate, tiles need to have the same dimensions. In this case, both tiles have 5 rows and 5 columns. If we tried to compute an element-wise local aggregate over the DataFrame without equal tile dimensions, we would get a runtime error.

rf.agg(rf_agg_local_mean('tile')) \
    .first()[0].cells.data  # display the contents of the Tile array
array([[3.086, 2.881, 2.859, 3.088, 2.98 ],
       [3.013, 2.874, 3.075, 2.988, 2.939],
       [2.902, 3.015, 2.937, 2.947, 3.007],
       [3.002, 2.975, 3.01 , 3.073, 2.904],
       [2.961, 2.944, 3.012, 2.987, 3.056]])
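The element-wise behavior can likewise be sketched in NumPy: each output cell is the mean of the corresponding cells across the stacked tiles. Seeded arrays stand in for the tiles above.

```python
import numpy as np

np.random.seed(0)  # seeded for reproducibility; the tiles above are unseeded
a1 = 1 + 0.1 * np.random.randn(5, 5)
a5 = 5 + 0.1 * np.random.randn(5, 5)

# Average cell (i, j) of a1 with cell (i, j) of a5, keeping the 5x5 shape
local_mean = np.stack([a1, a5]).mean(axis=0)
print(local_mean.shape)  # (5, 5): one mean per cell position
```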

Cell Counts Example

We can also count the total number of data and NoData cells over all the tiles in a DataFrame using rf_agg_data_cells and rf_agg_no_data_cells. There are ~3.8 million data cells and ~1.9 million NoData cells in this DataFrame. See the section on “NoData” handling for additional discussion on handling missing data.

rf = spark.read.raster('https://s22s-test-geotiffs.s3.amazonaws.com/MCD43A4.006/11/05/2018233/MCD43A4.A2018233.h11v05.006.2018242035530_B02.TIF')
stats = rf.agg(rf_agg_data_cells('proj_raster'), rf_agg_no_data_cells('proj_raster'))
stats
rf_agg_data_cells(proj_raster) rf_agg_no_data_cells(proj_raster)
3825959 1934041
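The two counts together describe data coverage; for instance, the NoData fraction (about 34% of this raster) follows directly from the values reported above:

```python
# Counts reported by rf_agg_data_cells / rf_agg_no_data_cells above
data_cells = 3825959
no_data_cells = 1934041

# Fraction of all cells that are NoData
no_data_fraction = no_data_cells / (data_cells + no_data_cells)
print(round(no_data_fraction, 3))  # → 0.336
```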

Statistical Summaries

The statistical summary functions return a summary of cell values: number of data cells, number of NoData cells, minimum, maximum, mean, and variance, which can be computed as a tile aggregate, a DataFrame aggregate, or an element-wise local aggregate.

The rf_tile_stats function computes summary statistics separately for each row in a tile column as shown below.

rf = spark.read.raster('https://s22s-test-geotiffs.s3.amazonaws.com/luray_snp/B02.tif')
stats = rf.select(rf_tile_stats('proj_raster').alias('stats'))

stats.printSchema()
root
 |-- stats: struct (nullable = true)
 |    |-- data_cells: long (nullable = false)
 |    |-- no_data_cells: long (nullable = false)
 |    |-- min: double (nullable = false)
 |    |-- max: double (nullable = false)
 |    |-- mean: double (nullable = false)
 |    |-- variance: double (nullable = false)
stats.select('stats.min', 'stats.max', 'stats.mean', 'stats.variance')

Showing only top 5 rows.

min max mean variance
199.0 5331.0 455.5312957763667 68728.44405841625
187.0 1911.0 310.51860046386815 13849.46165963797
181.0 1530.0 259.9732513427734 3501.4270036255507
170.0 2535.0 292.7441253662109 11137.773410618069
152.0 3641.0 290.0180511474609 14295.314752891418
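For intuition, the fields of each per-row summary correspond to the following NumPy computation on a single tile's cells. This is only a sketch with NaN standing in for NoData; note that np.nanvar computes the population variance, and whether RasterFrames uses population or sample variance is not shown here.

```python
import numpy as np

# A small array standing in for one tile's cells; NaN marks NoData
cells = np.array([[1.0, 2.0, np.nan],
                  [4.0, 5.0, 6.0]])

mask = np.isnan(cells)
stats = {
    'data_cells': int((~mask).sum()),
    'no_data_cells': int(mask.sum()),
    'min': float(np.nanmin(cells)),   # NaN-aware reductions skip NoData
    'max': float(np.nanmax(cells)),
    'mean': float(np.nanmean(cells)),
    'variance': float(np.nanvar(cells)),
}
print(stats)
```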

The rf_agg_stats function aggregates over all of the tiles in a DataFrame and returns a statistical summary of all cell values as shown below.

stats = rf.agg(rf_agg_stats('proj_raster').alias('stats')) \
    .select('stats.min', 'stats.max', 'stats.mean', 'stats.variance')
stats
min max mean variance
3.0 12103.0 542.1327946489893 685615.201702677
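The summary does not include a standard deviation column, but it follows directly from the variance value reported above:

```python
import math

variance = 685615.201702677  # 'variance' from the aggregate above
std_dev = math.sqrt(variance)
print(round(std_dev, 1))  # roughly 828
```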

The rf_agg_local_stats function computes the element-wise local aggregate statistical summary as shown below. The DataFrame used in the previous two code blocks has unequal tile dimensions, so a different DataFrame is used in this code block to avoid a runtime error.

rf = spark.createDataFrame([
    Row(id=1, tile=t1),
    Row(id=3, tile=t1 * 3),
    Row(id=5, tile=t1 * 5)
]).agg(rf_agg_local_stats('tile').alias('stats'))

agg_local_stats = rf.select('stats.min', 'stats.max', 'stats.mean', 'stats.variance').collect()

for r in agg_local_stats:
    for stat in r.asDict():
        print(stat, ':\n', r[stat], '\n')
min :
 Tile(dimensions=[5, 5], cell_type=CellType(float64, nan), cells=
[[1.0717808680594876 0.8670403306104628 0.8643800348943269
  1.0387945525974258 0.9780646776865474]
 [1.0680277153484548 0.8395983669648295 1.0226149771778756
  0.9534852980699325 0.9896123846642095]
 [0.8277225264968968 1.1489150607310257 0.864041463874536
  0.85432943117754 0.9599492130315758]
 [1.036490473333024 0.9868775697800407 1.0866450795221134
  1.1638225649187852 0.9206091011274136]
 [1.0784492634747642 0.8458677799246253 0.9665528423601178
  1.0033842448442043 1.1086921814137851]]) 

max :
 Tile(dimensions=[5, 5], cell_type=CellType(float64, nan), cells=
[[5.358904340297438 4.335201653052314 4.321900174471635 5.193972762987129
  4.890323388432737]
 [5.340138576742274 4.197991834824148 5.113074885889378 4.767426490349663
  4.948061923321047]
 [4.1386126324844845 5.744575303655129 4.320207319372679 4.2716471558877
  4.799746065157879]
 [5.18245236666512 4.9343878489002035 5.433225397610567 5.819112824593926
  4.603045505637068]
 [5.3922463173738215 4.2293388996231265 4.83276421180059
  5.016921224221021 5.543460907068925]]) 

mean :
 Tile(dimensions=[5, 5], cell_type=CellType(float64, nan), cells=
[[3.215342604178463 2.6011209918313885 2.593140104682981
  3.116383657792278 2.9341940330596423]
 [3.2040831460453645 2.5187951008944887 3.0678449315336267
  2.8604558942097977 2.9688371539926286]
 [2.4831675794906904 3.446745182193077 2.592124391623608 2.56298829353262
  2.879847639094727]
 [3.1094714199990725 2.9606327093401226 3.25993523856634
  3.4914676947563557 2.7618273033822405]
 [3.2353477904242927 2.537603339773876 2.8996585270803537
  3.0101527345326127 3.326076544241355]]) 

variance :
 Tile(dimensions=[5, 5], cell_type=CellType(float64, nan), cells=
[[3.0632379443689306 2.0046904930802683 1.9924075859304473
  2.877584326682893 2.5509613699682365]
 [3.041821868673173 1.8798011141600224 2.78864371079602 2.424357903028021
  2.6115537916820895]
 [1.8269988823210834 3.5200155113988725 1.9908470701185363
  1.9463434052697037 2.4573399775998457]
 [2.8648333368269743 2.5971395672932207 3.148793410265654
  3.6119545669710416 2.2600563122096657]
 [3.1014741703713664 1.9079794696389696 2.491265058865128
  2.684746514137533 3.2778622750081556]])
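The same element-wise reductions can be sketched in NumPy by stacking the tiles along a new axis and reducing along it. A seeded array stands in for t1 above, so the exact values differ.

```python
import numpy as np

np.random.seed(0)  # seeded for reproducibility; t1 above is unseeded
t = 1 + 0.1 * np.random.randn(5, 5)
stack = np.stack([t, 3 * t, 5 * t])  # mirrors the tiles with id 1, 3, 5

local_min = stack.min(axis=0)    # element-wise minimum across the tiles
local_max = stack.max(axis=0)    # element-wise maximum
local_mean = stack.mean(axis=0)  # element-wise mean

# With every cell of t positive here, the min tile is t itself and the max is 5 * t
print(np.allclose(local_min, t), np.allclose(local_max, 5 * t))
```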

Histogram

The rf_tile_histogram function computes a count of cell values within each row of the tile column and outputs a bins array with the schema below. In the graph below, we have plotted each bin's value on the x-axis and its count on the y-axis for the tile in the first row of the DataFrame.

import matplotlib.pyplot as plt

rf = spark.read.raster('https://s22s-test-geotiffs.s3.amazonaws.com/MCD43A4.006/11/05/2018233/MCD43A4.A2018233.h11v05.006.2018242035530_B02.TIF')

hist_df = rf.select(rf_tile_histogram('proj_raster')['bins'].alias('bins'))
hist_df.printSchema()
root
 |-- bins: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- value: double (nullable = false)
 |    |    |-- count: long (nullable = false)

bins_row = hist_df.first()
values = [int(bin['value']) for bin in bins_row.bins]
counts = [int(bin['count']) for bin in bins_row.bins]

plt.hist(values, weights=counts, bins=100)
plt.show()

The rf_agg_approx_histogram function computes a count of cell values across all of the rows of the tile column in a DataFrame or group. In the example below, the range of the y-axis is significantly wider than in the previous histogram, since this histogram is computed over all cell values in the DataFrame.

bins_list = rf.agg(
    rf_agg_approx_histogram('proj_raster')['bins'].alias('bins')
    ).collect()
values = [int(bin['value']) for bin in bins_list[0].bins]
counts = [int(bin['count']) for bin in bins_list[0].bins]

plt.hist(values, weights=counts, bins=100)
plt.show()
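The bins array can also be reduced numerically rather than plotted; for example, an approximate mean is the count-weighted average of bin values. A sketch with hypothetical bins standing in for a real histogram result:

```python
# Hypothetical (value, count) bins standing in for rf_tile_histogram output
bins = [(100.0, 5), (200.0, 10), (300.0, 5)]

total = sum(count for _, count in bins)
approx_mean = sum(value * count for value, count in bins) / total
print(approx_mean)  # → 200.0
```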