python+gdal高分二号影像批量预处理(正射校正+融合+金字塔创建)
本文仅仅实现了正射校正、全色和多光谱的融合以及创建金字塔。如果需要的是高精度的定量反演,还需要进行辐射定标和大气校正。
详细信息
1. 前言
本文仅仅实现了正射校正、全色和多光谱的融合以及创建金字塔。如果需要的是高精度的定量反演,还需要进行辐射定标和大气校正。
数据来源为中国资源卫星应用中心陆地观测卫星数据服务平台。
2. 具体步骤
2.1 解压
因为数据下载下来为压缩包,单独解压比较麻烦,所以将解压也放在了批量操作里。
import tarfile
# 解压
def unpackage(file_name):
    """Extract a tar archive into a directory named after the archive.

    The output directory is the archive path with its archive suffix
    removed, e.g. ``/data/scene_E110.6.tar.gz`` -> ``/data/scene_E110.6``.

    Args:
        file_name: Path of the tar archive (any compression that
            ``tarfile.open`` auto-detects).

    Returns:
        Path of the directory the archive was extracted into.
    """
    import os  # local import: this snippet only imports tarfile at file level

    # Strip a known archive suffix from the END of the path only. Splitting on
    # the first "." of the whole path would truncate names such as
    # "GF2_PMS1_E110.6_N36.3_....tar" (or any directory containing a dot).
    lowered = file_name.lower()
    for suffix in (".tar.gz", ".tar.bz2", ".tar.xz", ".tgz", ".tar"):
        if lowered.endswith(suffix):
            out_dir = file_name[:-len(suffix)]
            break
    else:
        out_dir = os.path.splitext(file_name)[0]

    # An archive without any extension would otherwise be extracted "into
    # itself"; give it a distinct directory name instead.
    if out_dir == file_name:
        out_dir = file_name + "_extracted"

    with tarfile.open(file_name) as archive:
        if hasattr(tarfile, "data_filter"):
            # Downloaded archives are untrusted input: the "data" filter
            # refuses members that would escape the destination directory.
            archive.extractall(path=out_dir, filter="data")
        else:
            # Older Python without extraction filters.
            archive.extractall(path=out_dir)
    return out_dir
2.2 正射校正
正射校正一般是通过在像片上选取一些地面控制点,并利用已经获取的该像片范围内的DEM,对影像同时进行倾斜改正和投影差改正,将影像重采样成正射影像。本文不手动选取控制点,而是使用GDAL的Warp函数,基于有理多项式系数RPC(Rational Polynomial Coefficients)模型并结合DEM进行校正。本地没有高精度DEM的话可以使用ENVI自带的GMTED2010.jp2,路径一般为"Exelis\ENVI53\data\GMTED2010.jp2"。
import os
from osgeo import gdal, osr
# 正射校正
def _utm_epsg_from_scene_name(file_name):
    """Derive the WGS 84 / UTM EPSG code from a scene file name.

    The base name is expected to look like
    ``GF2_PMS1_E110.6_N36.3_20200727_...``: the third ``_``-separated field
    is the hemisphere letter plus longitude, the fourth starts with ``N``
    for the northern hemisphere.
    """
    fields = os.path.basename(file_name).split('_')
    if len(fields) < 4:
        raise ValueError("Cannot read longitude/latitude from file name: %s"
                         % file_name)
    lon_field, lat_field = fields[2], fields[3]
    lon = float(lon_field[1:])
    if lon_field[0] == 'W':
        lon = -lon
    # Zones are 6 degrees wide and numbered 1..60 starting at 180 W; clamp so
    # that a longitude of exactly 180 E maps to zone 60.
    zone = min(int((lon + 180) // 6) + 1, 60)
    # EPSG 326xx = UTM north, 327xx = UTM south. Adding the zone number keeps
    # the code correct for single-digit zones as well.
    return (32600 if lat_field[0] == 'N' else 32700) + zone


def ortho(file_name, dem_name, res, out_file_name):
    """Orthorectify an image with its RPC model and a DEM, writing a GeoTIFF.

    Args:
        file_name: Input image; its base name must carry the scene centre as
            ``..._E110.6_N36.3_...`` (used to pick the UTM zone).
        dem_name: Path of the DEM handed to the RPC transformer (``RPC_DEM``).
        res: Output pixel size, used for both x and y, in the units of the
            target UTM projection.
        out_file_name: Path of the orthorectified GeoTIFF to create.

    Raises:
        RuntimeError: If the input cannot be opened or the warp fails.
        ValueError: If the file name does not contain the expected fields.
    """
    dataset = gdal.Open(file_name, gdal.GA_ReadOnly)
    if dataset is None:
        raise RuntimeError("Cannot open %s" % file_name)

    dstSRS = osr.SpatialReference()
    dstSRS.ImportFromEPSG(_utm_epsg_from_scene_name(file_name))

    warped = gdal.Warp(out_file_name, dataset, format='GTiff',
                       xRes=res, yRes=res, dstSRS=dstSRS,
                       rpc=True, resampleAlg=gdal.GRIORA_Bilinear,
                       transformerOptions=["RPC_DEM=" + dem_name])
    if warped is None:
        raise RuntimeError("Orthorectification failed for %s" % file_name)

    # Drop both references explicitly so GDAL flushes and closes the output
    # before the caller reads it back.
    warped = None
    dataset = None
2.3 多光谱全色融合
对于融合,我们使用GDAL自带的脚本gdal_pansharpen.py,一般在Lib/site-packages/osgeo/utils文件夹中。为了方便在Python代码中直接调用,对脚本文件稍作修改:
import os
import os.path
import sys
from osgeo import gdal
def DoesDriverHandleExtension(drv, ext):
    """Return True if ``ext`` occurs in the driver's DMD_EXTENSIONS metadata.

    The comparison is case-insensitive and is a plain substring test against
    the metadata string. A driver without that metadata item yields False.
    """
    advertised = drv.GetMetadataItem(gdal.DMD_EXTENSIONS)
    if advertised is None:
        return False
    return ext.lower() in advertised.lower()
def GetExtension(filename):
    """Return the extension of ``filename`` without its leading dot.

    Returns an empty string when the name has no extension.
    """
    suffix = os.path.splitext(filename)[1]
    return suffix[1:] if suffix[:1] == '.' else suffix
def GetOutputDriversFor(filename):
    """List short names of raster drivers able to write ``filename``.

    A driver qualifies when it supports Create or CreateCopy, is a raster
    driver, and either advertises the file's extension or has a connection
    prefix that ``filename`` starts with (case-insensitive).

    Returns:
        Driver short names in registration order, except that for ``.nc``
        files netCDF is put ahead of GMT.
    """
    drv_list = []
    ext = GetExtension(filename)
    for i in range(gdal.GetDriverCount()):
        drv = gdal.GetDriver(i)
        can_write = (drv.GetMetadataItem(gdal.DCAP_CREATE) is not None or
                     drv.GetMetadataItem(gdal.DCAP_CREATECOPY) is not None)
        if not can_write or drv.GetMetadataItem(gdal.DCAP_RASTER) is None:
            continue
        if ext and DoesDriverHandleExtension(drv, ext):
            drv_list.append(drv.ShortName)
        else:
            prefix = drv.GetMetadataItem(gdal.DMD_CONNECTION_PREFIX)
            if prefix is not None and filename.lower().startswith(prefix.lower()):
                drv_list.append(drv.ShortName)

    # GMT is registered before netCDF for opening reasons, but we want
    # netCDF to be used by default for output. The list must hold at least
    # two entries before indexing it: the previous "not drv_list" test made
    # this swap unreachable and raised IndexError on an empty list.
    if ext.lower() == 'nc' and len(drv_list) >= 2 and \
            drv_list[0].upper() == 'GMT' and drv_list[1].upper() == 'NETCDF':
        drv_list = ['NETCDF', 'GMT']
    return drv_list
def GetOutputDriverFor(filename):
    """Pick the output driver short name for ``filename``.

    Uses the first candidate from ``GetOutputDriversFor`` and prints a notice
    when several candidates match. With no candidate, a name without an
    extension defaults to ``'GTiff'``; otherwise an ``Exception`` is raised.
    """
    candidates = GetOutputDriversFor(filename)
    extension = GetExtension(filename)

    if candidates:
        chosen = candidates[0]
        if len(candidates) > 1:
            print("Several drivers matching %s extension. Using %s" % (extension if extension else '', chosen))
        return chosen

    if extension:
        raise Exception("Cannot guess driver for %s" % filename)
    return 'GTiff'
def Usage():
    """Print the command-line help text to stdout and return -1."""
    help_lines = (
        'Usage: gdal_pansharpen [--help-general] pan_dataset {spectral_dataset[,band=num]}+ out_dataset',
        ' [-of format] [-b band]* [-w weight]*',
        ' [-r {nearest,bilinear,cubic,cubicspline,lanczos,average}]',
        ' [-threads {ALL_CPUS|number}] [-bitdepth val] [-nodata val]',
        ' [-spat_adjust {union,intersection,none,nonewithoutwarning}]',
        ' [-verbose_vrt] [-co NAME=VALUE]* [-q]',
        '',
        'Create a dataset resulting from a pansharpening operation.',
    )
    for help_line in help_lines:
        print(help_line)
    return -1
def gdal_pansharpen(argv):
    """Pansharpen spectral bands with a panchromatic band.

    Builds a ``VRTPansharpenedDataset`` description from the arguments and
    either writes it out as a VRT file or materialises it with the chosen
    driver's ``CreateCopy``.

    Args:
        argv: Command-line style list. ``argv[0]`` is ignored (parsing starts
            at index 1). Positional arguments are, in order: the panchromatic
            dataset, one or more spectral datasets (optionally suffixed with
            ``,band=N``), and the output dataset. Options are those printed
            by ``Usage()``.

    Returns:
        0 on success, 1 when a dataset cannot be opened/created or a ``-b`` /
        ``-w`` value is invalid, -1 when the arguments are rejected.
    """
    argv = gdal.GeneralCmdLineProcessor(argv)
    if argv is None:
        return -1

    pan_name = None
    last_name = None
    spectral_ds = []
    spectral_bands = []
    bands = []
    weights = []
    frmt = None
    creation_options = []
    callback = gdal.TermProgress_nocb
    resampling = None
    spat_adjust = None
    verbose_vrt = False
    num_threads = None
    bitdepth = None
    nodata = None

    i = 1
    argc = len(argv)
    while i < argc:
        arg = argv[i]
        has_value = i < argc - 1  # options below consume the next argument
        if arg in ('-of', '-f') and has_value:
            frmt = argv[i + 1]
            i = i + 1
        elif arg == '-r' and has_value:
            resampling = argv[i + 1]
            i = i + 1
        elif arg == '-spat_adjust' and has_value:
            spat_adjust = argv[i + 1]
            i = i + 1
        elif arg == '-b' and has_value:
            bands.append(int(argv[i + 1]))
            i = i + 1
        elif arg == '-w' and has_value:
            weights.append(float(argv[i + 1]))
            i = i + 1
        elif arg == '-co' and has_value:
            creation_options.append(argv[i + 1])
            i = i + 1
        elif arg == '-threads' and has_value:
            num_threads = argv[i + 1]
            i = i + 1
        elif arg == '-bitdepth' and has_value:
            bitdepth = argv[i + 1]
            i = i + 1
        elif arg == '-nodata' and has_value:
            nodata = argv[i + 1]
            i = i + 1
        elif arg == '-q':
            callback = None
        elif arg == '-verbose_vrt':
            verbose_vrt = True
        elif arg.startswith('-'):
            sys.stderr.write('Unrecognized option : %s\n' % arg)
            return Usage()
        elif pan_name is None:
            pan_name = arg
            # Opened only to verify that the panchromatic dataset is readable.
            pan_ds = gdal.Open(pan_name)
            if pan_ds is None:
                return 1
        else:
            # The last positional argument is the output name, so a positional
            # is only treated as a spectral input once another one follows it.
            if last_name is not None:
                pos = last_name.find(',band=')
                if pos > 0:
                    ds = gdal.Open(last_name[0:pos])
                    if ds is None:
                        return 1
                    band_num = int(last_name[pos + len(',band='):])
                    spectral_ds.append(ds)
                    spectral_bands.append(ds.GetRasterBand(band_num))
                else:
                    ds = gdal.Open(last_name)
                    if ds is None:
                        return 1
                    for j in range(ds.RasterCount):
                        spectral_ds.append(ds)
                        spectral_bands.append(ds.GetRasterBand(j + 1))
            last_name = arg
        i = i + 1

    if pan_name is None or not spectral_bands:
        return Usage()
    out_name = last_name

    if frmt is None:
        frmt = GetOutputDriverFor(out_name)

    all_bands = [j + 1 for j in range(len(spectral_bands))]
    if not bands:
        bands = all_bands
    else:
        for band in bands:
            # Band numbers are 1-based. 0 must be rejected too, otherwise
            # spectral_bands[band - 1] silently picks the LAST band.
            if band < 1 or band > len(spectral_bands):
                print('Invalid band number in -b: %d' % band)
                return 1

    if weights and len(weights) != len(spectral_bands):
        print('There must be as many -w values specified as input spectral bands')
        return 1

    is_vrt_output = frmt.upper() == 'VRT'
    # For VRT output, relative inputs are stored relative to the VRT location.
    vrt_dir = os.path.dirname(out_name)

    xml_parts = ["""<VRTDataset subClass="VRTPansharpenedDataset">\n"""]

    # Output bands only need declaring when they differ from "all, in order".
    if bands != all_bands:
        for out_idx, band in enumerate(bands):
            sband = spectral_bands[band - 1]
            datatype = gdal.GetDataTypeName(sband.DataType)
            colorname = gdal.GetColorInterpretationName(sband.GetColorInterpretation())
            xml_parts.append(
                ' <VRTRasterBand dataType="%s" band="%d" subClass="VRTPansharpenedRasterBand">\n'
                '<ColorInterp>%s</ColorInterp>\n'
                '</VRTRasterBand>\n' % (datatype, out_idx + 1, colorname))

    xml_parts.append(""" <PansharpeningOptions>\n""")

    if weights:
        xml_parts.append(""" <AlgorithmOptions>\n""")
        xml_parts.append(""" <Weights>""")
        xml_parts.append(",".join("%.16g" % weight for weight in weights))
        xml_parts.append("</Weights>\n")
        xml_parts.append(""" </AlgorithmOptions>\n""")

    if resampling is not None:
        xml_parts.append(' <Resampling>%s</Resampling>\n' % resampling)
    if num_threads is not None:
        xml_parts.append(' <NumThreads>%s</NumThreads>\n' % num_threads)
    if bitdepth is not None:
        xml_parts.append(' <BitDepth>%s</BitDepth>\n' % bitdepth)
    if nodata is not None:
        xml_parts.append(' <NoData>%s</NoData>\n' % nodata)
    if spat_adjust is not None:
        xml_parts.append(' <SpatialExtentAdjustment>%s</SpatialExtentAdjustment>\n' % spat_adjust)

    pan_relative = '0'
    if is_vrt_output and not os.path.isabs(pan_name):
        pan_relative = '1'
        pan_name = os.path.relpath(pan_name, vrt_dir)
    xml_parts.append(
        ' <PanchroBand>\n'
        '<SourceFilename relativeToVRT="%s">%s</SourceFilename>\n'
        '<SourceBand>1</SourceBand>\n'
        '</PanchroBand>\n' % (pan_relative, pan_name))

    for src_idx, sband in enumerate(spectral_bands):
        # Map this input band to the first output position that selects it.
        dstband = ''
        for out_idx, band in enumerate(bands):
            if src_idx + 1 == band:
                dstband = ' dstBand="%d"' % (out_idx + 1)
                break

        ms_relative = '0'
        ms_name = spectral_ds[src_idx].GetDescription()
        if is_vrt_output and not os.path.isabs(ms_name):
            ms_relative = '1'
            ms_name = os.path.relpath(ms_name, vrt_dir)

        xml_parts.append(
            ' <SpectralBand%s>\n'
            '<SourceFilename relativeToVRT="%s">%s</SourceFilename>\n'
            '<SourceBand>%d</SourceBand>\n'
            '</SpectralBand>\n' % (dstband, ms_relative, ms_name, sband.GetBand()))

    xml_parts.append(""" </PansharpeningOptions>\n""")
    xml_parts.append("""</VRTDataset>\n""")
    vrt_xml = "".join(xml_parts)

    if is_vrt_output:
        f = gdal.VSIFOpenL(out_name, 'wb')
        if f is None:
            print('Cannot create %s' % out_name)
            return 1
        # Write encoded bytes and pass the BYTE count: len() of the str is
        # smaller than the UTF-8 size when paths contain non-ASCII characters,
        # which would truncate the VRT file.
        vrt_bytes = vrt_xml.encode('utf-8')
        gdal.VSIFWriteL(vrt_bytes, 1, len(vrt_bytes), f)
        gdal.VSIFCloseL(f)
        if verbose_vrt:
            vrt_ds = gdal.Open(out_name, gdal.GA_Update)
            if vrt_ds is None:
                return 1
            # Re-setting the metadata on a dataset opened for update makes
            # GDAL rewrite the VRT file when it is closed.
            vrt_ds.SetMetadata(vrt_ds.GetMetadata())
        else:
            vrt_ds = gdal.Open(out_name)
            if vrt_ds is None:
                return 1
        return 0

    vrt_ds = gdal.Open(vrt_xml)
    if vrt_ds is None:
        return 1
    driver = gdal.GetDriverByName(frmt)
    if driver is None:
        print('Cannot find driver for format %s' % frmt)
        return 1
    out_ds = driver.CreateCopy(out_name, vrt_ds, 0, creation_options, callback=callback)
    if out_ds is None:
        return 1
    # Release the output explicitly so it is flushed to disk before returning.
    out_ds = None
    return 0
# if __name__ == '__main__':
# pan_path = r"E:\WangZhenQing\FGMS-Dataset\GF2_原始数据集\大宁县\GF2_PMS1_E110.6_N36.3_20200727_L1A0004953367\GF2_PMS1_E110.6_N36.3_20200727_L1A0004953367-PAN1_ortho_dem.tiff"
# mss_path = r"E:\WangZhenQing\FGMS-Dataset\GF2_原始数据集\大宁县\GF2_PMS1_E110.6_N36.3_20200727_L1A0004953367\GF2_PMS1_E110.6_N36.3_20200727_L1A0004953367-MSS1_ortho_dem.tiff"
# pansharpen_path = pan_path.replace("PAN1_ortho_dem.tiff", "pansharpen.tiff")
# gdal_pansharpen(["pass",pan_path,mss_path,pansharpen_path])
2.4 创建金字塔
图像融合后文件会很大,后续在ArcGIS中打开时会卡顿,所以需要创建图像金字塔,直接利用GDAL的BuildOverviews函数即可实现:
from osgeo import gdal
def build_pyramid(file_name, levels=(2, 4, 8, 16), resampling="NEAREST"):
    """Build overview (pyramid) levels for a raster file.

    Args:
        file_name: Path of the raster to build overviews for.
        levels: Overview decimation factors. Defaults to 2, 4, 8 and 16.
        resampling: Resampling method name passed to ``BuildOverviews``.

    Raises:
        RuntimeError: If the file cannot be opened or overview creation
            reports a failure.
    """
    # NOTE(review): the dataset is opened with the default (read-only)
    # access; for GeoTIFF this presumably stores the overviews in an
    # external .ovr file next to the image - confirm if internal overviews
    # are wanted.
    dataset = gdal.Open(file_name)
    if dataset is None:
        raise RuntimeError("Cannot open %s" % file_name)
    status = dataset.BuildOverviews(resampling, overviewlist=list(levels))
    # Drop the reference so GDAL flushes and closes the file.
    dataset = None
    if status:
        raise RuntimeError("Building overviews failed for %s" % file_name)
2.5 主函数
使用glob获取文件夹内的所有压缩包,再接一个for循环就可以实现批量预处理了。
import glob
import os
from unpackage import unpackage
from ortho import ortho
from pansharpen import gdal_pansharpen
from osgeo import gdal
from build_pyramid import build_pyramid
import warnings
# Suppress all Python warnings for the whole batch run.
warnings.filterwarnings('ignore')
# Ask GDAL to raise Python exceptions on errors instead of only returning
# None / error codes.
gdal.UseExceptions()
def _find_raw_band_tiff(scene_dir, tag):
    """Return the raw ``tag`` (``PAN`` / ``MSS``) .tiff inside ``scene_dir``.

    Products written by an earlier run (``*_ortho.tiff``, ``*pansharpen.tiff``)
    are ignored so that re-running the batch never picks a derived file.
    """
    candidates = sorted(
        path for path in glob.glob(scene_dir + "/*" + tag + "*.tiff")
        if not path.lower().endswith(("_ortho.tiff", "pansharpen.tiff"))
    )
    if not candidates:
        raise FileNotFoundError("No raw %s .tiff found in %s" % (tag, scene_dir))
    return candidates[0]


def preprocess(dem_path, tar_dir, pan_res=0.8, mss_res=3.2):
    """Batch-preprocess every ``*.tar.gz`` scene archive found in ``tar_dir``.

    Each archive is extracted; its panchromatic and multispectral images are
    orthorectified, pansharpened, and overviews are built for the result.

    Args:
        dem_path: DEM used for the orthorectification.
        tar_dir: Directory holding the ``.tar.gz`` archives.
        pan_res: Output pixel size for the panchromatic image.
        mss_res: Output pixel size for the multispectral image.

    Raises:
        FileNotFoundError: If an extracted scene lacks a PAN or MSS .tiff.
        RuntimeError: If pansharpening reports a failure.
    """
    tar_paths = sorted(glob.glob(tar_dir + "/*.tar.gz"))
    tar_unpackage_dirs = []

    print("开始解压...")
    for tar_index, tar_path in enumerate(tar_paths, start=1):
        print(f"{tar_index}/{len(tar_paths)}")
        print(os.path.basename(tar_path))
        tar_unpackage_dirs.append(unpackage(tar_path))

    print("开始正射校正与融合...")
    for scene_index, scene_dir in enumerate(tar_unpackage_dirs, start=1):
        print(f"{scene_index}/{len(tar_unpackage_dirs)}")

        # Orthorectify the panchromatic image.
        pan_path = _find_raw_band_tiff(scene_dir, "PAN")
        # Replace only the trailing extension, never a ".tiff" elsewhere in
        # the path.
        pan_ortho_path = os.path.splitext(pan_path)[0] + "_ortho.tiff"
        print(os.path.basename(pan_path), "正射校正...")
        ortho(pan_path, dem_path, pan_res, pan_ortho_path)

        # Orthorectify the multispectral image.
        mss_path = _find_raw_band_tiff(scene_dir, "MSS")
        mss_ortho_path = os.path.splitext(mss_path)[0] + "_ortho.tiff"
        print(os.path.basename(mss_path), "正射校正...")
        ortho(mss_path, dem_path, mss_res, mss_ortho_path)

        # Pansharpen. The output name is derived from the file's base name
        # only, so a "PAN" inside a directory name cannot corrupt the path.
        print("融合...")
        scene_prefix = os.path.basename(pan_ortho_path).split("PAN")[0]
        pansharpen_path = os.path.join(os.path.dirname(pan_ortho_path),
                                       scene_prefix + "pansharpen.tiff")
        # argv[0] is a placeholder; gdal_pansharpen starts parsing at index 1.
        status = gdal_pansharpen(["pass", pan_ortho_path, mss_ortho_path, pansharpen_path])
        if status != 0:
            raise RuntimeError("Pansharpening failed for %s" % scene_dir)

        print("创建金字塔...")
        build_pyramid(pansharpen_path)
if __name__ == '__main__':
    preprocess(
        # DEM bundled with ENVI.
        dem_path=r"C:\setup\Exelis\ENVI53\data\GMTED2010.jp2",
        # Directory that holds the downloaded archives.
        tar_dir=r"文件夹",
    )