如有研究需要数据,请通过QQ 1807232 联系,或微信扫码添加客服咨询。

python+gdal高分二号影像批量预处理(正射校正+融合+金字塔创建)

本文仅仅实现了正射校正、全色和多光谱的融合以及创建金字塔。如果需要的是高精度的定量反演,还需要进行辐射定标和大气校正。

1. 前言

本文仅仅实现了正射校正、全色和多光谱的融合以及创建金字塔。如果需要的是高精度的定量反演,还需要进行辐射定标和大气校正。

数据来源为中国资源卫星应用中心陆地观测卫星数据服务平台。

2. 具体步骤

2.1 解压

因为数据下载下来为压缩包,单独解压比较麻烦,所以将解压也放在了批量操作里。

import tarfile

# 解压
def unpackage(file_name):
    """Extract a tar archive into a folder named after the archive.

    The output folder is the archive path with its archive suffix removed
    (``.tar.gz``, ``.tar.bz2``, ``.tar.xz``, ``.tgz`` or ``.tar``), e.g.
    ``.../GF2_PMS1_E110.6_N36.3_..._L1A0004953367.tar.gz`` ->
    ``.../GF2_PMS1_E110.6_N36.3_..._L1A0004953367``.

    Args:
        file_name: path to the archive.

    Returns:
        The path of the folder the archive was extracted into.
    """
    # Local import: this snippet is saved as its own module (unpackage.py),
    # which otherwise only imports tarfile.
    import os

    # Strip only a *trailing* archive suffix. Splitting on the first "."
    # truncated GF-2 names such as "E110.6_N36.3" and broke relative paths
    # such as "./data/x.tar" (which produced an empty folder name).
    lower_name = file_name.lower()
    for suffix in (".tar.gz", ".tar.bz2", ".tar.xz", ".tgz", ".tar"):
        if lower_name.endswith(suffix):
            out_dir = file_name[:-len(suffix)]
            break
    else:
        out_dir = os.path.splitext(file_name)[0]

    with tarfile.open(file_name) as archive:
        if hasattr(tarfile, "data_filter"):
            # Extraction filters exist on patched/new Pythons: "data" rejects
            # absolute paths, ".." members and links pointing outside out_dir.
            archive.extractall(path=out_dir, filter="data")
        else:
            archive.extractall(path=out_dir)
    return out_dir

2.2 正射校正

正射校正一般是利用影像覆盖范围内已有的DEM(必要时辅以在像片上选取的地面控制点),对影像同时进行倾斜改正和投影差改正,并将影像重采样为正射影像。本文不使用地面控制点,而是使用GDAL的Warp函数,基于影像自带的有理多项式系数(RPC,Rational Polynomial Coefficients)模型结合DEM进行校正,输出投影为根据文件名中经纬度计算得到的WGS 84 / UTM分带。本地没有高精度DEM的话,可以使用ENVI自带的GMTED2010.jp2,路径一般为ENVI安装目录下的"Exelis\ENVI53\data\GMTED2010.jp2"。

import os
from osgeo import gdal, osr


# 正射校正
def _utm_epsg_from_name(file_name):
    """Return the WGS 84 / UTM EPSG code for a GF-2 scene file name.

    GF-2 base names look like ``GF2_PMS1_E110.6_N36.3_20200727_...``: after
    splitting on ``_``, field 2 is the scene longitude (``E``/``W`` prefix)
    and field 3 the latitude (``N``/``S`` prefix).
    """
    parts = os.path.basename(file_name).split('_')
    lon_field, lat_field = parts[2], parts[3]
    lon = float(lon_field[1:])
    if lon_field[0].upper() == 'W':
        lon = -lon
    # Standard UTM zone number (1..60); lon == 180 is folded into zone 60.
    zone = min(int((lon + 180.0) // 6.0) + 1, 60)
    # 326xx = northern hemisphere, 327xx = southern. Integer arithmetic keeps
    # single-digit zones correct (32605, not the "3265" string concatenation
    # would give).
    base = 32600 if lat_field[0] == 'N' else 32700
    return base + zone


def ortho(file_name, dem_name, res, out_file_name):
    """Orthorectify a GF-2 image with its RPC model and a DEM.

    The output is a GeoTIFF in the WGS 84 / UTM zone derived from the scene
    coordinates embedded in *file_name*, resampled bilinearly.

    Args:
        file_name: input image (its RPC metadata is used via ``rpc=True``).
        dem_name: DEM passed to the RPC transformer as ``RPC_DEM``.
        res: output pixel size in target CRS units (metres for UTM).
        out_file_name: path of the GeoTIFF to create.

    Raises:
        RuntimeError: if the input cannot be opened or warping fails (only
            reachable when GDAL exceptions are disabled; otherwise GDAL
            raises its own error).
    """
    dataset = gdal.Open(file_name, gdal.GA_ReadOnly)
    if dataset is None:
        raise RuntimeError("Cannot open %s" % file_name)

    dst_srs = osr.SpatialReference()
    dst_srs.ImportFromEPSG(_utm_epsg_from_name(file_name))

    out_ds = gdal.Warp(out_file_name, dataset, format='GTiff',
                       xRes=res, yRes=res, dstSRS=dst_srs,
                       rpc=True, resampleAlg=gdal.GRIORA_Bilinear,
                       transformerOptions=["RPC_DEM=" + dem_name])
    if out_ds is None:
        raise RuntimeError("gdal.Warp failed for %s" % file_name)
    # Dereference both datasets so GDAL flushes and closes the output before
    # the next processing step opens it.
    dataset = out_ds = None

2.3 多光谱全色融合

对于融合算法,我们使用GDAL自带的脚本gdal_pansharpen.py(较早版本一般位于Lib/site-packages/osgeo/utils文件夹中,较新版本位于osgeo_utils包中)。为了方便运行,对脚本文件稍作修改,修改后可在Python中直接调用gdal_pansharpen(argv)函数(argv的第一个元素为占位用的程序名):

import os
import os.path
import sys
from osgeo import gdal


def DoesDriverHandleExtension(drv, ext):
    """Return True if GDAL driver *drv* lists *ext* among its extensions.

    ``DMD_EXTENSIONS`` is a space-separated list (e.g. ``"tif tiff"``), so
    *ext* is compared against whole entries, case-insensitively. A substring
    search would also accept partial extensions such as ``"ti"`` or ``"jp"``.
    """
    exts = drv.GetMetadataItem(gdal.DMD_EXTENSIONS)
    return exts is not None and ext.lower() in exts.lower().split()


def GetExtension(filename):
    """Return the last extension of *filename* without the leading dot.

    Returns an empty string when the name has no extension (including
    dot-files such as ``.bashrc``).
    """
    _, suffix = os.path.splitext(filename)
    return suffix[1:] if suffix.startswith('.') else suffix


def GetOutputDriversFor(filename):
    """Return short names of GDAL raster drivers able to write *filename*.

    A driver qualifies when it can create rasters (DCAP_CREATE or
    DCAP_CREATECOPY, plus DCAP_RASTER) and either handles the file's
    extension or, if it does not, declares a connection prefix that
    *filename* starts with. Drivers are returned in registration order.
    """
    ext = GetExtension(filename)
    drv_list = []
    for i in range(gdal.GetDriverCount()):
        drv = gdal.GetDriver(i)
        can_create = (drv.GetMetadataItem(gdal.DCAP_CREATE) is not None or
                      drv.GetMetadataItem(gdal.DCAP_CREATECOPY) is not None)
        if not can_create or drv.GetMetadataItem(gdal.DCAP_RASTER) is None:
            continue
        if ext and DoesDriverHandleExtension(drv, ext):
            drv_list.append(drv.ShortName)
        else:
            prefix = drv.GetMetadataItem(gdal.DMD_CONNECTION_PREFIX)
            if prefix is not None and filename.lower().startswith(prefix.lower()):
                drv_list.append(drv.ShortName)

    # GMT is registered before netCDF for opening reasons, but netCDF should
    # be the default for output. This needs at least two matches; testing
    # `not drv_list` here (as before) never swapped and raised IndexError for
    # a ".nc" name that no driver matched.
    if (ext.lower() == 'nc' and len(drv_list) >= 2 and
            drv_list[0].upper() == 'GMT' and drv_list[1].upper() == 'NETCDF'):
        drv_list[0], drv_list[1] = drv_list[1], drv_list[0]

    return drv_list


def GetOutputDriverFor(filename):
    """Choose the GDAL output driver short name for *filename*.

    Uses the first matching writable driver, printing a note when several
    match. With no match, falls back to 'GTiff' for extension-less names and
    raises ``Exception`` when an extension is present but unsupported.
    """
    candidates = GetOutputDriversFor(filename)
    extension = GetExtension(filename)
    if candidates:
        if len(candidates) > 1:
            print("Several drivers matching %s extension. Using %s"
                  % (extension or '', candidates[0]))
        return candidates[0]
    if extension:
        raise Exception("Cannot guess driver for %s" % filename)
    return 'GTiff'


def Usage():
    """Print the gdal_pansharpen command-line synopsis; return -1 (error)."""
    synopsis = (
        'Usage: gdal_pansharpen [--help-general] pan_dataset {spectral_dataset[,band=num]}+ out_dataset',
        '                       [-of format] [-b band]* [-w weight]*',
        '                       [-r {nearest,bilinear,cubic,cubicspline,lanczos,average}]',
        '                       [-threads {ALL_CPUS|number}] [-bitdepth val] [-nodata val]',
        '                       [-spat_adjust {union,intersection,none,nonewithoutwarning}]',
        '                       [-verbose_vrt] [-co NAME=VALUE]* [-q]',
        '',
        'Create a dataset resulting from a pansharpening operation.',
    )
    for text_line in synopsis:
        print(text_line)
    return -1


def gdal_pansharpen(argv):
    """Pansharpen spectral bands with a panchromatic band (GDAL script port).

    *argv* is a command line whose first element is ignored (a program
    name), e.g. ``gdal_pansharpen(["pass", pan_path, mss_path, out_path])``.
    Positional arguments are the panchromatic dataset, one or more spectral
    datasets (``name`` or ``name,band=N``) and, last, the output dataset;
    options are listed by ``Usage()``. A ``VRTPansharpenedDataset`` XML
    description is assembled; for ``-of VRT`` it is written to the output
    path, otherwise it is copied to the output with the selected driver.

    Returns:
        0 on success, 1 on failure, -1 on invalid arguments.
    """
    # Local import: XML-escape file names (e.g. paths containing "&") before
    # embedding them in the VRT description.
    from xml.sax.saxutils import escape

    argv = gdal.GeneralCmdLineProcessor(argv)
    if argv is None:
        return -1

    pan_name = None
    last_name = None
    spectral_ds = []
    spectral_bands = []
    out_name = None
    bands = []
    weights = []
    frmt = None
    creation_options = []
    callback = gdal.TermProgress_nocb
    resampling = None
    spat_adjust = None
    verbose_vrt = False
    num_threads = None
    bitdepth = None
    nodata = None

    # Each positional after the pan dataset is parked in ``last_name`` and
    # only opened as a spectral dataset once another positional follows it,
    # so the final positional is left over as the output name.
    i = 1
    argc = len(argv)
    while i < argc:
        if (argv[i] == '-of' or argv[i] == '-f') and i < len(argv) - 1:
            frmt = argv[i + 1]
            i = i + 1
        elif argv[i] == '-r' and i < len(argv) - 1:
            resampling = argv[i + 1]
            i = i + 1
        elif argv[i] == '-spat_adjust' and i < len(argv) - 1:
            spat_adjust = argv[i + 1]
            i = i + 1
        elif argv[i] == '-b' and i < len(argv) - 1:
            bands.append(int(argv[i + 1]))
            i = i + 1
        elif argv[i] == '-w' and i < len(argv) - 1:
            weights.append(float(argv[i + 1]))
            i = i + 1
        elif argv[i] == '-co' and i < len(argv) - 1:
            creation_options.append(argv[i + 1])
            i = i + 1
        elif argv[i] == '-threads' and i < len(argv) - 1:
            num_threads = argv[i + 1]
            i = i + 1
        elif argv[i] == '-bitdepth' and i < len(argv) - 1:
            bitdepth = argv[i + 1]
            i = i + 1
        elif argv[i] == '-nodata' and i < len(argv) - 1:
            nodata = argv[i + 1]
            i = i + 1
        elif argv[i] == '-q':
            callback = None
        elif argv[i] == '-verbose_vrt':
            verbose_vrt = True
        elif argv[i].startswith('-'):
            # startswith() instead of argv[i][0]: an empty argument must not
            # raise IndexError.
            sys.stderr.write('Unrecognized option : %s\n' % argv[i])
            return Usage()
        elif pan_name is None:
            pan_name = argv[i]
            # Opened only to validate that the pan dataset is readable.
            pan_ds = gdal.Open(pan_name)
            if pan_ds is None:
                return 1
        else:
            if last_name is not None:
                pos = last_name.find(',band=')
                if pos > 0:
                    # "name,band=N": use a single band of the dataset.
                    spectral_name = last_name[0:pos]
                    ds = gdal.Open(spectral_name)
                    if ds is None:
                        return 1
                    band_num = int(last_name[pos + len(',band='):])
                    band = ds.GetRasterBand(band_num)
                    spectral_ds.append(ds)
                    spectral_bands.append(band)
                else:
                    # Plain name: use every band of the dataset.
                    spectral_name = last_name
                    ds = gdal.Open(spectral_name)
                    if ds is None:
                        return 1
                    for j in range(ds.RasterCount):
                        spectral_ds.append(ds)
                        spectral_bands.append(ds.GetRasterBand(j + 1))

            last_name = argv[i]

        i = i + 1

    if pan_name is None or not spectral_bands:
        return Usage()
    out_name = last_name
    if frmt is None:
        frmt = GetOutputDriverFor(out_name)

    if not bands:
        bands = [j + 1 for j in range(len(spectral_bands))]
    else:
        for band in bands:
            # -b values are 1-based; 0 used to pass this check and silently
            # select the last spectral band via spectral_bands[-1].
            if band < 1 or band > len(spectral_bands):
                print('Invalid band number in -b: %d' % band)
                return 1

    if weights and len(weights) != len(spectral_bands):
        print('There must be as many -w values specified as input spectral bands')
        return 1

    vrt_xml = """<VRTDataset subClass="VRTPansharpenedDataset">\n"""
    # Explicit output bands are only needed when -b selects/reorders bands.
    if bands != [j + 1 for j in range(len(spectral_bands))]:
        for i, band in enumerate(bands):
            sband = spectral_bands[band - 1]
            datatype = gdal.GetDataTypeName(sband.DataType)
            colorname = gdal.GetColorInterpretationName(sband.GetColorInterpretation())
            vrt_xml += """  <VRTRasterBand dataType="%s" band="%d" subClass="VRTPansharpenedRasterBand">
      <ColorInterp>%s</ColorInterp>
  </VRTRasterBand>\n""" % (datatype, i + 1, colorname)

    vrt_xml += """  <PansharpeningOptions>\n"""

    if weights:
        vrt_xml += """      <AlgorithmOptions>\n"""
        vrt_xml += """        <Weights>"""
        vrt_xml += ",".join("%.16g" % weight for weight in weights)
        vrt_xml += "</Weights>\n"
        vrt_xml += """      </AlgorithmOptions>\n"""

    if resampling is not None:
        vrt_xml += '      <Resampling>%s</Resampling>\n' % resampling

    if num_threads is not None:
        vrt_xml += '      <NumThreads>%s</NumThreads>\n' % num_threads

    if bitdepth is not None:
        vrt_xml += '      <BitDepth>%s</BitDepth>\n' % bitdepth

    if nodata is not None:
        vrt_xml += '      <NoData>%s</NoData>\n' % nodata

    if spat_adjust is not None:
        vrt_xml += '      <SpatialExtentAdjustment>%s</SpatialExtentAdjustment>\n' % spat_adjust

    # A VRT written to disk stores relative source paths relative to itself.
    pan_relative = '0'
    if frmt.upper() == 'VRT':
        if not os.path.isabs(pan_name):
            pan_relative = '1'
            pan_name = os.path.relpath(pan_name, os.path.dirname(out_name))

    vrt_xml += """    <PanchroBand>
      <SourceFilename relativeToVRT="%s">%s</SourceFilename>
      <SourceBand>1</SourceBand>
    </PanchroBand>\n""" % (pan_relative, escape(pan_name))

    for i, sband in enumerate(spectral_bands):
        dstband = ''
        for j, band in enumerate(bands):
            if i + 1 == band:
                dstband = ' dstBand="%d"' % (j + 1)
                break

        ms_relative = '0'
        ms_name = spectral_ds[i].GetDescription()
        if frmt.upper() == 'VRT':
            if not os.path.isabs(ms_name):
                ms_relative = '1'
                ms_name = os.path.relpath(ms_name, os.path.dirname(out_name))

        vrt_xml += """    <SpectralBand%s>
      <SourceFilename relativeToVRT="%s">%s</SourceFilename>
      <SourceBand>%d</SourceBand>
    </SpectralBand>\n""" % (dstband, ms_relative, escape(ms_name), sband.GetBand())

    vrt_xml += """  </PansharpeningOptions>\n"""
    vrt_xml += """</VRTDataset>\n"""

    if frmt.upper() == 'VRT':
        f = gdal.VSIFOpenL(out_name, 'wb')
        if f is None:
            print('Cannot create %s' % out_name)
            return 1
        # The count passed to VSIFWriteL is in bytes. Passing the str with
        # its character count truncated the XML whenever a path contained
        # non-ASCII characters (e.g. Chinese folder names).
        data = vrt_xml.encode('utf-8')
        written = gdal.VSIFWriteL(data, 1, len(data), f)
        gdal.VSIFCloseL(f)
        if written != len(data):
            print('Cannot write %s' % out_name)
            return 1
        if verbose_vrt:
            vrt_ds = gdal.Open(out_name, gdal.GA_Update)
        else:
            vrt_ds = gdal.Open(out_name)
        if vrt_ds is None:
            return 1
        if verbose_vrt:
            # Re-setting metadata makes the VRT driver write it into the file.
            vrt_ds.SetMetadata(vrt_ds.GetMetadata())
        vrt_ds = None
        return 0

    # Non-VRT output: open the XML description in memory and copy it out.
    vrt_ds = gdal.Open(vrt_xml)
    if vrt_ds is None:
        return 1
    driver = gdal.GetDriverByName(frmt)
    if driver is None:
        print('Unknown output driver %s' % frmt)
        return 1
    out_ds = driver.CreateCopy(out_name, vrt_ds, 0, creation_options, callback=callback)
    if out_ds is None:
        return 1
    # Close explicitly so the file is flushed before callers reopen it.
    out_ds = vrt_ds = None
    return 0

# if __name__ == '__main__':
#     pan_path = r"E:\WangZhenQing\FGMS-Dataset\GF2_原始数据集\大宁县\GF2_PMS1_E110.6_N36.3_20200727_L1A0004953367\GF2_PMS1_E110.6_N36.3_20200727_L1A0004953367-PAN1_ortho_dem.tiff"
#     mss_path = r"E:\WangZhenQing\FGMS-Dataset\GF2_原始数据集\大宁县\GF2_PMS1_E110.6_N36.3_20200727_L1A0004953367\GF2_PMS1_E110.6_N36.3_20200727_L1A0004953367-MSS1_ortho_dem.tiff"
#     pansharpen_path = pan_path.replace("PAN1_ortho_dem.tiff", "pansharpen.tiff")
#     gdal_pansharpen(["pass",pan_path,mss_path,pansharpen_path])

2.4 创建金字塔

图像融合后文件会很大,后续在ArcGIS中打开时会卡顿,所以需要创建影像金字塔,直接利用GDAL的BuildOverviews函数即可实现:

from osgeo import gdal

def build_pyramid(file_name, overview_levels=(2, 4, 8, 16), resampling="NEAREST"):
    """Build overview (pyramid) levels for a raster.

    The file is opened read-only. For formats such as GeoTIFF, GDAL then
    writes the overviews to an external ``.ovr`` file beside the image.

    Args:
        file_name: raster to build overviews for.
        overview_levels: decimation factors (default 2, 4, 8, 16).
        resampling: GDAL overview resampling method name, e.g. "NEAREST"
            (GDAL's default) or "AVERAGE".

    Raises:
        RuntimeError: if the file cannot be opened or building fails (only
            reachable when GDAL exceptions are disabled).
    """
    dataset = gdal.Open(file_name)
    if dataset is None:
        raise RuntimeError("Cannot open %s" % file_name)
    try:
        err = dataset.BuildOverviews(resampling=resampling,
                                     overviewlist=list(overview_levels))
    finally:
        # Release the dataset even on failure so the .ovr file is closed.
        dataset = None
    if err != gdal.CE_None:
        raise RuntimeError("BuildOverviews failed for %s" % file_name)

2.5 主函数

使用glob获取文件夹内的所有压缩包,再接一个for循环就可以实现批量预处理了。前面各节的代码需分别保存为unpackage.py、ortho.py、pansharpen.py和build_pyramid.py,并与下面的主程序放在同一目录下。

import glob
import os
from unpackage import unpackage
from ortho import ortho
from pansharpen import gdal_pansharpen
from osgeo import gdal
from build_pyramid import build_pyramid
import warnings
# Silence every Python warning for the whole run (this also hides GDAL
# deprecation notices; narrow the filter if those are needed).
warnings.simplefilter('ignore')

# Make GDAL raise Python exceptions instead of returning None / error codes.
gdal.UseExceptions()

def _find_source_tiff(scene_dir, tag):
    """Return the original *tag* ("PAN"/"MSS") TIFF in *scene_dir*, or None."""
    pattern = os.path.join(glob.escape(scene_dir), "*%s*.tiff" % tag)
    candidates = []
    for path in glob.glob(pattern):
        stem = os.path.splitext(os.path.basename(path))[0]
        # Skip outputs of earlier runs: "*_ortho*.tiff" also matches the tag,
        # and Windows glob is case-insensitive so "*PAN*" matches
        # "...-pansharpen.tiff" as well.
        if "_ortho" in stem or stem.lower().endswith("pansharpen"):
            continue
        candidates.append(path)
    candidates.sort()
    return candidates[0] if candidates else None


def _ortho_path(src_path):
    """Return *src_path* with "_ortho" inserted before its extension."""
    root, ext = os.path.splitext(src_path)
    return root + "_ortho" + ext


def _process_scene(scene_dir, dem_path, pan_res=0.8, mss_res=3.2):
    """Orthorectify, pansharpen and build overviews for one extracted scene."""
    pan_path = _find_source_tiff(scene_dir, "PAN")
    mss_path = _find_source_tiff(scene_dir, "MSS")
    if pan_path is None or mss_path is None:
        print("跳过(未找到PAN或MSS影像):", scene_dir)
        return

    # Panchromatic orthorectification
    pan_ortho_path = _ortho_path(pan_path)
    print(os.path.basename(pan_path), "正射校正...")
    ortho(pan_path, dem_path, pan_res, pan_ortho_path)

    # Multispectral orthorectification
    mss_ortho_path = _ortho_path(mss_path)
    print(os.path.basename(mss_path), "正射校正...")
    ortho(mss_path, dem_path, mss_res, mss_ortho_path)

    # Fusion: name the result after the part of the PAN file name before
    # "PAN", searched in the base name only so folder names cannot interfere.
    print("融合...")
    pan_dir, pan_ortho_name = os.path.split(pan_ortho_path)
    pansharpen_path = os.path.join(
        pan_dir, pan_ortho_name.split("PAN")[0] + "pansharpen.tiff")
    if gdal_pansharpen(["pass", pan_ortho_path, mss_ortho_path, pansharpen_path]) != 0:
        print("融合失败,跳过:", pansharpen_path)
        return

    print("创建金字塔...")
    build_pyramid(pansharpen_path)


def preprocess(dem_path, tar_dir):
    """Batch-process every ``*.tar.gz`` GF-2 archive in *tar_dir*.

    All archives are first extracted. Then, for each scene, the PAN image is
    orthorectified at 0.8 and the MSS image at 3.2 (output CRS units) using
    *dem_path*. The two are pansharpened, and overviews are built for the
    result. Scenes lacking a PAN/MSS image, or whose fusion fails, are
    reported and skipped so one bad archive does not stop the batch.
    """
    tar_paths = sorted(glob.glob(os.path.join(glob.escape(tar_dir), "*.tar.gz")))
    print("开始解压...")
    scene_dirs = []
    for index, tar_path in enumerate(tar_paths, start=1):
        print(f"{index}/{len(tar_paths)}")
        print(os.path.basename(tar_path))
        scene_dirs.append(unpackage(tar_path))

    print("开始正射校正与融合...")
    for index, scene_dir in enumerate(scene_dirs, start=1):
        print(f"{index}/{len(scene_dirs)}")
        _process_scene(scene_dir, dem_path)
    
if __name__ == '__main__':
    import sys

    # Usage: python <this script> [dem_path] [tar_dir]
    # Default DEM: the GMTED2010 DEM shipped with ENVI 5.3.
    # Default folder: the placeholder "文件夹" ("folder"). Replace it, or pass
    # the directory holding the downloaded *.tar.gz archives.
    dem_path = sys.argv[1] if len(sys.argv) > 1 else r"C:\setup\Exelis\ENVI53\data\GMTED2010.jp2"
    tar_dir = sys.argv[2] if len(sys.argv) > 2 else r"文件夹"
    if not os.path.isdir(tar_dir):
        # Without this, glob finds nothing and the run silently does nothing.
        sys.exit("找不到压缩包所在文件夹: %s" % tar_dir)
    preprocess(dem_path, tar_dir)

文章来源于"地理遥感生态网科学数据注册与出版系统"(www.gisrs.cn);