Unsupervised Machine Learning

In this example, we will demonstrate how to fit and score an unsupervised learning model with a sample of Landsat 8 data.

Imports and Data Preparation

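We import the pandas, RasterFrames, and SparkML components needed to construct our Pipeline. The code in this example assumes an active SparkSession named spark with RasterFrames enabled.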

import pandas as pd
from pyrasterframes import TileExploder
from pyrasterframes.rasterfunctions import *

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.clustering import KMeans
from pyspark.ml import Pipeline

The first step is to create a Spark DataFrame of our imagery data. To do that, we build a catalog DataFrame using the pattern from the I/O page. In the catalog, each row represents a distinct area and time, and each column is the URI to a band’s image product. Because each image is read as a set of tiles, the resulting Spark DataFrame may have many rows per URI, with a column corresponding to each band.

filenamePattern = "https://rasterframes.s3.amazonaws.com/samples/elkton/L8-B{}-Elkton-VA.tiff"
catalog_df = pd.DataFrame([
    {'b' + str(b): filenamePattern.format(b) for b in range(1, 8)}
])

tile_size = 256
df = spark.read.raster(catalog_df, catalog_col_names=catalog_df.columns, tile_size=tile_size)
df = df.withColumn('crs', rf_crs(df.b1)) \
       .withColumn('extent', rf_extent(df.b1))
df.printSchema()
root
 |-- b1_path: string (nullable = false)
 |-- b2_path: string (nullable = false)
 |-- b3_path: string (nullable = false)
 |-- b4_path: string (nullable = false)
 |-- b5_path: string (nullable = false)
 |-- b6_path: string (nullable = false)
 |-- b7_path: string (nullable = false)
 |-- b1: struct (nullable = true)
 |    |-- tile_context: struct (nullable = true)
 |    |    |-- extent: struct (nullable = false)
 |    |    |    |-- xmin: double (nullable = false)
 |    |    |    |-- ymin: double (nullable = false)
 |    |    |    |-- xmax: double (nullable = false)
 |    |    |    |-- ymax: double (nullable = false)
 |    |    |-- crs: struct (nullable = false)
 |    |    |    |-- crsProj4: string (nullable = false)
 |    |-- tile: tile (nullable = false)
 |-- b2: struct (nullable = true)
 |    |-- tile_context: struct (nullable = true)
 |    |    |-- extent: struct (nullable = false)
 |    |    |    |-- xmin: double (nullable = false)
 |    |    |    |-- ymin: double (nullable = false)
 |    |    |    |-- xmax: double (nullable = false)
 |    |    |    |-- ymax: double (nullable = false)
 |    |    |-- crs: struct (nullable = false)
 |    |    |    |-- crsProj4: string (nullable = false)
 |    |-- tile: tile (nullable = false)
 |-- b3: struct (nullable = true)
 |    |-- tile_context: struct (nullable = true)
 |    |    |-- extent: struct (nullable = false)
 |    |    |    |-- xmin: double (nullable = false)
 |    |    |    |-- ymin: double (nullable = false)
 |    |    |    |-- xmax: double (nullable = false)
 |    |    |    |-- ymax: double (nullable = false)
 |    |    |-- crs: struct (nullable = false)
 |    |    |    |-- crsProj4: string (nullable = false)
 |    |-- tile: tile (nullable = false)
 |-- b4: struct (nullable = true)
 |    |-- tile_context: struct (nullable = true)
 |    |    |-- extent: struct (nullable = false)
 |    |    |    |-- xmin: double (nullable = false)
 |    |    |    |-- ymin: double (nullable = false)
 |    |    |    |-- xmax: double (nullable = false)
 |    |    |    |-- ymax: double (nullable = false)
 |    |    |-- crs: struct (nullable = false)
 |    |    |    |-- crsProj4: string (nullable = false)
 |    |-- tile: tile (nullable = false)
 |-- b5: struct (nullable = true)
 |    |-- tile_context: struct (nullable = true)
 |    |    |-- extent: struct (nullable = false)
 |    |    |    |-- xmin: double (nullable = false)
 |    |    |    |-- ymin: double (nullable = false)
 |    |    |    |-- xmax: double (nullable = false)
 |    |    |    |-- ymax: double (nullable = false)
 |    |    |-- crs: struct (nullable = false)
 |    |    |    |-- crsProj4: string (nullable = false)
 |    |-- tile: tile (nullable = false)
 |-- b6: struct (nullable = true)
 |    |-- tile_context: struct (nullable = true)
 |    |    |-- extent: struct (nullable = false)
 |    |    |    |-- xmin: double (nullable = false)
 |    |    |    |-- ymin: double (nullable = false)
 |    |    |    |-- xmax: double (nullable = false)
 |    |    |    |-- ymax: double (nullable = false)
 |    |    |-- crs: struct (nullable = false)
 |    |    |    |-- crsProj4: string (nullable = false)
 |    |-- tile: tile (nullable = false)
 |-- b7: struct (nullable = true)
 |    |-- tile_context: struct (nullable = true)
 |    |    |-- extent: struct (nullable = false)
 |    |    |    |-- xmin: double (nullable = false)
 |    |    |    |-- ymin: double (nullable = false)
 |    |    |    |-- xmax: double (nullable = false)
 |    |    |    |-- ymax: double (nullable = false)
 |    |    |-- crs: struct (nullable = false)
 |    |    |    |-- crsProj4: string (nullable = false)
 |    |-- tile: tile (nullable = false)
 |-- crs: struct (nullable = true)
 |    |-- crsProj4: string (nullable = false)
 |-- extent: struct (nullable = true)
 |    |-- xmin: double (nullable = false)
 |    |-- ymin: double (nullable = false)
 |    |-- xmax: double (nullable = false)
 |    |-- ymax: double (nullable = false)
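To make the catalog layout concrete, here is a minimal plain-Python sketch (no Spark or pandas) of the single catalog row built above. The helper tiles_per_image and the 600 x 500 image size are hypothetical, added only to illustrate why one URI can yield several DataFrame rows. It assumes the reader splits each image on a regular grid with partial tiles at the right and bottom edges.

```python
import math

# Same URI pattern as in the example above.
filename_pattern = "https://rasterframes.s3.amazonaws.com/samples/elkton/L8-B{}-Elkton-VA.tiff"

# One catalog row: keys b1..b7, each mapped to that band's URI.
catalog_row = {"b" + str(b): filename_pattern.format(b) for b in range(1, 8)}


def tiles_per_image(width, height, tile_size=256):
    """Hypothetical helper: number of tiles when an image is cut on a regular grid.

    Assumes partial tiles are kept at the right and bottom edges.
    """
    return math.ceil(width / tile_size) * math.ceil(height / tile_size)


for band, uri in catalog_row.items():
    print(band, uri)

# Made-up image dimensions, for illustration only.
print(tiles_per_image(600, 500))
```

Under those assumptions, a single catalog row expands into one DataFrame row per tile position, and every band column in that row covers the same extent.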

In this small example, all the images in our catalog_df have the same CRS, which we verify in the code snippet below. The crs object will be useful for visualization later.

crses = df.select('crs.crsProj4').distinct().collect()
print('Found ', len(crses), 'distinct CRS: ', crses)
assert len(crses) == 1
crs = crses[0]['crsProj4']
Found  1 distinct CRS:  [Row(crsProj4='+proj=utm +zone=17 +datum=WGS84 +units=m +no_defs ')]

Create ML Pipeline

SparkML requires that each observation be in its own row, and that the features for each observation be packed into a single Vector. For this unsupervised learning problem, we will treat each pixel as an observation and each band as a feature. The first step is to “explode” the tiles into a single row per pixel. RasterFrames generally refers to a pixel as a cell.

exploder = TileExploder()
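As a rough illustration of what exploding means, the following plain-Python sketch turns a small 2 x 2 tile into one record per cell. The function explode_tile and the cell values are made up for this example and are not part of RasterFrames. The column_index and row_index names mirror the columns that appear in the scored output later on.

```python
def explode_tile(tile):
    """Yield one (column_index, row_index, value) record per cell of a 2-D tile."""
    for row_index, row in enumerate(tile):
        for column_index, value in enumerate(row):
            yield column_index, row_index, value


# A made-up 2 x 2 tile; real tiles in this example are up to 256 x 256.
tile = [[1, 2],
        [3, 4]]

cells = list(explode_tile(tile))
for cell in cells:
    print(cell)
```

Keeping the cell position alongside each value is what later allows the predictions to be placed back into a tile.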

To “vectorize” the band columns, we use the SparkML VectorAssembler. Each of the seven bands is a different feature.

assembler = VectorAssembler() \
    .setInputCols(list(catalog_df.columns)) \
    .setOutputCol("features")
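Conceptually, the assembler reads the named input columns of each row and packs their values, in order, into one vector. A minimal sketch of that idea in plain Python follows. The function assemble_features is hypothetical and stands in for the SparkML stage. The band values are those of the first row in the sample output shown later on this page.

```python
def assemble_features(row, input_cols):
    """Pack the values of input_cols, in order, into a single list of floats."""
    return [float(row[col]) for col in input_cols]


band_cols = ["b" + str(b) for b in range(1, 8)]

# One exploded cell: seven band values plus its position within the tile.
cell_row = dict(
    zip(band_cols, [9470, 8491, 7805, 6697, 17507, 10338, 7235]),
    column_index=0,
    row_index=0,
)

features = assemble_features(cell_row, band_cols)
print(features)
```

Only the band columns feed the vector. The position columns are carried along untouched.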

For this problem, we will use the K-means clustering algorithm and configure our model to have 5 clusters.

kmeans = KMeans().setK(5).setFeaturesCol('features')
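For readers unfamiliar with K-means, the sketch below shows the basic iteration (Lloyd's algorithm) in plain Python on four made-up 2-D points. It is an illustration of the idea only, not Spark's implementation: here the initial centers are simply the first k points, which is an assumption of this sketch, and all function names are hypothetical.

```python
def squared_distance(p, q):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def nearest(point, centers):
    """Index of the center closest to point."""
    return min(range(len(centers)), key=lambda i: squared_distance(point, centers[i]))


def kmeans(points, k, iterations=20):
    """Lloyd's algorithm with a naive initialization (first k points)."""
    centers = [list(p) for p in points[:k]]
    for _ in range(iterations):
        # Assignment step: group points by their nearest center.
        clusters = [[] for _ in range(k)]
        for p in points:
            clusters[nearest(p, centers)].append(p)
        # Update step: move each center to the mean of its members.
        new_centers = []
        for i, members in enumerate(clusters):
            if members:
                new_centers.append([sum(dim) / len(members) for dim in zip(*members)])
            else:
                new_centers.append(centers[i])  # keep an empty cluster's center
        if new_centers == centers:
            break  # converged
        centers = new_centers
    return centers


points = [[1.0, 1.0], [1.0, 2.0], [9.0, 9.0], [10.0, 9.0]]
centers = kmeans(points, k=2)
labels = [nearest(p, centers) for p in points]
print(centers, labels)
```

In this example each point would be a seven-element band vector rather than a 2-D point, and k would be 5.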

We can combine the above stages into a single Pipeline.

pipeline = Pipeline().setStages([exploder, assembler, kmeans])

Fit the Model and Score

Fitting the pipeline is what actually triggers the work: exploding the tiles, assembling the feature vectors, and fitting the K-means clustering model.

model = pipeline.fit(df)

We can use the fitted pipeline model’s transform method to score the training data. This adds a column called prediction that holds the identifier of the closest cluster for each cell.

clustered = model.transform(df)

Now let’s take a look at some sample output.

clustered.select('prediction', 'extent', 'column_index', 'row_index', 'features')

Showing only top 5 rows.

prediction | extent | column_index | row_index | features
0 | [703986.502389, 4249551.61978, 709549.093643, 4254601.8671] | 0 | 0 | [9470.0,8491.0,7805.0,6697.0,17507.0,10338.0,7235.0]
0 | [703986.502389, 4249551.61978, 709549.093643, 4254601.8671] | 1 | 0 | [9566.0,8607.0,8046.0,6898.0,18504.0,11545.0,7877.0]
2 | [703986.502389, 4249551.61978, 709549.093643, 4254601.8671] | 2 | 0 | [9703.0,8808.0,8377.0,7222.0,20556.0,13207.0,8686.0]
2 | [703986.502389, 4249551.61978, 709549.093643, 4254601.8671] | 3 | 0 | [9856.0,8983.0,8565.0,7557.0,19479.0,13203.0,9065.0]
1 | [703986.502389, 4249551.61978, 709549.093643, 4254601.8671] | 4 | 0 | [10105.0,9270.0,8851.0,7912.0,19074.0,12737.0,8947.0]

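If we want to inspect the model statistics, the SparkML API requires us to go through this unfortunate contortion to access the clustering results: retrieving the fitted K-means stage by its position in the pipeline (index 2, after the exploder and the assembler).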

cluster_stage = model.stages[2]

We can then compute the sum of squared distances from each point to its nearest cluster center, which is the basis of most cluster quality metrics.

metric = cluster_stage.computeCost(clustered)
print("Within set sum of squared errors: %s" % metric)
Within set sum of squared errors: 249577233410.3331
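To show what this number measures, here is a small plain-Python sketch of the within-set sum of squared errors on made-up points and centers. The function names are hypothetical. Note also that KMeansModel.computeCost was deprecated in Spark 2.4 and removed in Spark 3.0, so on newer Spark versions the call above would need to be replaced, for example by an evaluator or by the cost reported in the model's training summary.

```python
def squared_distance(p, q):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def within_set_sum_of_squared_errors(points, centers):
    """Sum, over all points, of the squared distance to the nearest center."""
    return sum(min(squared_distance(p, c) for c in centers) for p in points)


# Made-up data: two tight groups and their means as cluster centers.
points = [[1.0, 1.0], [1.0, 2.0], [9.0, 9.0], [10.0, 9.0]]
centers = [[1.0, 1.5], [9.5, 9.0]]

wssse = within_set_sum_of_squared_errors(points, centers)
print("Within set sum of squared errors: %s" % wssse)
```

Lower values mean the points sit closer to their assigned centers. The value depends on the scale of the features, so it is mainly useful for comparing models fitted on the same data.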

Visualize Prediction

We can recreate the tiled data structure using the metadata added by the TileExploder pipeline stage.

from pyrasterframes.rf_types import CellType

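# Group cells that came from the same tile, then place each prediction
# back at its (column_index, row_index) position in a tile_size x tile_size tile.
retiled = clustered.groupBy('extent', 'crs') \
    .agg(
        rf_assemble_tile('column_index', 'row_index', 'prediction',
                         tile_size, tile_size, CellType.int8())
    )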
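The regrouping above can be pictured with a plain-Python sketch: cells are grouped by the extent they came from, and each prediction is written back at its row and column position. The function assemble_tiles, the extent labels, and the cell values are all made up for illustration. This is a sketch of the idea, not of how rf_assemble_tile is implemented.

```python
from collections import defaultdict


def assemble_tiles(cells, cols, rows, fill=None):
    """Group (extent, column_index, row_index, value) records into 2-D tiles.

    Positions with no record keep the fill value.
    """
    tiles = defaultdict(lambda: [[fill] * cols for _ in range(rows)])
    for extent, column_index, row_index, value in cells:
        tiles[extent][row_index][column_index] = value
    return dict(tiles)


# Made-up exploded cells from two 2 x 2 tiles, identified by a label per extent.
cells = [
    ("extent_a", 0, 0, 0), ("extent_a", 1, 0, 2),
    ("extent_a", 0, 1, 1), ("extent_a", 1, 1, 1),
    ("extent_b", 0, 0, 3), ("extent_b", 1, 1, 4),
]

tiles = assemble_tiles(cells, cols=2, rows=2)
for extent, tile in tiles.items():
    print(extent, tile)
```

Because the result is keyed by extent, the order in which cells arrive does not matter, which is why the aggregation works after a shuffle.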

Next we will write the output to a GeoTIFF file. Doing so in this case works quickly and well for a few specific reasons that may not hold in all cases. We can write the data at full resolution by omitting the raster_dimensions argument, because we know the input raster dimensions are small. Also, the data is all in a single CRS, as we demonstrated above. Because catalog_df has only a single row, we know the output GeoTIFF value at a given location corresponds to a single input. Finally, the retiled DataFrame has only a single Tile column, so the band interpretation is trivial.

import rasterio
output_tif = 'unsupervised.tif'

retiled.write.geotiff(output_tif, crs=crs)

with rasterio.open(output_tif) as src:
    for b in range(1, src.count + 1):
        print("Tags on band", b, src.tags(b))
    display(src)
Tags on band 1 {'RF_COL': 'prediction'}
<open DatasetReader name='unsupervised.tif' mode='r'>