From 49cc107a7568a598626715e278ae289d48c102a0 Mon Sep 17 00:00:00 2001 From: "yufei.chen" Date: Thu, 27 Oct 2022 15:05:46 +0800 Subject: [PATCH] add distributed multicards codes to hashnerf MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit link #I5Y00Y hashnerf 添加多卡代码 Signed-off-by: yufei.chen --- 3d-reconstruction/hashnerf/main_nerf.py | 54 ++++++- 3d-reconstruction/hashnerf/nerf/provider.py | 9 +- 3d-reconstruction/hashnerf/nerf/utils.py | 147 +++++++++++++++++--- 3d-reconstruction/hashnerf/readme.md | 14 +- 3d-reconstruction/hashnerf/train.sh | 7 + 5 files changed, 198 insertions(+), 33 deletions(-) create mode 100644 3d-reconstruction/hashnerf/train.sh diff --git a/3d-reconstruction/hashnerf/main_nerf.py b/3d-reconstruction/hashnerf/main_nerf.py index eb6cbd1ca..b1cb3cf9b 100644 --- a/3d-reconstruction/hashnerf/main_nerf.py +++ b/3d-reconstruction/hashnerf/main_nerf.py @@ -9,12 +9,49 @@ from functools import partial from loss import huber_loss #torch.autograd.set_detect_anomaly(True) +torch.backends.cudnn.benchmark = True + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def init_distributed_mode(args): + if args.num_gpus > 1 and 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + env_rank = int(os.environ["RANK"]) + env_world_size = int(os.environ["WORLD_SIZE"]) + env_gpu = int(os.environ['LOCAL_RANK']) + else: + print('Not using distributed mode') + return + + dist_backend = "nccl" + print('| distributed init (rank {}) (size {})'.format(env_rank,env_world_size), flush=True) + torch.distributed.init_process_group(backend=dist_backend, init_method='env://', + world_size=env_world_size, rank=env_rank) + + torch.cuda.set_device(env_gpu) + torch.distributed.barrier() + setup_for_distributed(env_rank == 0) + if __name__ == '__main__': print('[TIMESTAMP] start time:', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) parser = argparse.ArgumentParser() parser.add_argument('path', type=str) parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --preload") + parser.add_argument('--num_gpus', type=int, default=1) parser.add_argument('--test', action='store_true', help="test mode") parser.add_argument('--workspace', type=str, default='workspace') parser.add_argument('--seed', type=int, default=0) @@ -39,7 +76,7 @@ if __name__ == '__main__': parser.add_argument('--view', type=str, default='yaw', help="view direction:random or yaw") ### evaluate options - parser.add_argument('--eval_interval', type=int, default=5, help="eval_interval") + parser.add_argument('--eval_interval', type=int, default=1, help="eval_interval") ### network backbone options parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") @@ -77,6 +114,8 @@ if __name__ == '__main__': opt = parser.parse_args() + init_distributed_mode(opt) + if opt.O: opt.fp16 = True opt.cuda_ray = True @@ -122,7 +161,16 @@ if __name__ == '__main__': #criterion = torch.nn.HuberLoss(reduction='none', beta=0.1) # only available after torch 1.10 ? 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - + + if opt.num_gpus > 1: + env_rank = int(os.environ["RANK"]) + env_world_size = int(os.environ["WORLD_SIZE"]) + env_gpu = int(os.environ['LOCAL_RANK']) + else: + env_rank = 0 + env_world_size = 1 + env_gpu = 0 + if opt.test: metrics = [PSNRMeter(), LPIPSMeter(device=device)] @@ -156,7 +204,7 @@ if __name__ == '__main__': scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1)) metrics = [PSNRMeter(), LPIPSMeter(device=device)] - trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, metrics=metrics, use_checkpoint=opt.ckpt, eval_interval=opt.eval_interval) + trainer = Trainer('ngp', opt, model, local_rank=env_gpu, world_size=env_world_size, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, metrics=metrics, use_checkpoint=opt.ckpt, eval_interval=opt.eval_interval) if opt.gui: gui = NeRFGUI(opt, trainer, train_loader) diff --git a/3d-reconstruction/hashnerf/nerf/provider.py b/3d-reconstruction/hashnerf/nerf/provider.py index 1449966e8..ded5a201d 100644 --- a/3d-reconstruction/hashnerf/nerf/provider.py +++ b/3d-reconstruction/hashnerf/nerf/provider.py @@ -96,6 +96,7 @@ class NeRFDataset: super().__init__() self.opt = opt + self.num_gpus = opt.num_gpus self.device = device self.type = type # train, val, test self.downscale = downscale @@ -342,7 +343,13 @@ class NeRFDataset: size = len(self.poses) if self.training and self.rand_pose > 0: size += size // self.rand_pose # index >= size means we use random pose. - loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0) + if self.num_gpus > 1: + train_sampler = torch.utils.data.distributed.DistributedSampler(list(range(size))) + loader = DataLoader(list(range(size)), batch_size=1, sampler=train_sampler, collate_fn=self.collate, num_workers=0) + else: + loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0) + loader._data = self # an ugly fix... we need to access error_map & poses in trainer. 
loader.has_gt = self.images is not None return loader + diff --git a/3d-reconstruction/hashnerf/nerf/utils.py b/3d-reconstruction/hashnerf/nerf/utils.py index dc6342e2e..3fcdf151e 100644 --- a/3d-reconstruction/hashnerf/nerf/utils.py +++ b/3d-reconstruction/hashnerf/nerf/utils.py @@ -30,6 +30,113 @@ from torch_ema import ExponentialMovingAverage from packaging import version as pver import lpips +from torch.distributed.algorithms.join import ( + Join, + Joinable, + JoinHook, +) +import logging + + + +class DDP(torch.nn.parallel.DistributedDataParallel): + def __init__(self, module, device_ids, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, find_unused_parameters=False, check_reduction=False): + super().__init__(module, device_ids, output_device, dim, broadcast_buffers, process_group, bucket_cap_mb, find_unused_parameters, check_reduction) + + def render(self, *inputs, **kwargs): + with torch.autograd.profiler.record_function("DistributedDataParallel.forward"): + if torch.is_grad_enabled() and self.require_backward_grad_sync: + self.logger.set_runtime_stats_and_log() + self.num_iterations += 1 + self.reducer.prepare_for_forward() + + # Notify the join context that this process has not joined, if + # needed + work = Join.notify_join_context(self) + if work: + self.reducer._set_forward_pass_work_handle( + work, self._divide_by_initial_world_size + ) + + # Calling _rebuild_buckets before forward compuation, + # It may allocate new buckets before deallocating old buckets + # inside _rebuild_buckets. To save peak memory usage, + # call _rebuild_buckets before the peak memory usage increases + # during forward computation. + # This should be called only once during whole training period. + if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): + logging.info("Reducer buckets have been rebuilt in this iteration.") + self._has_rebuilt_buckets = True + + if self.require_forward_param_sync: + self._sync_params() + + if self._join_config.enable: + # Notify joined ranks whether they should sync in backwards pass or not. + self._check_global_requires_backward_grad_sync(is_joined_rank=False) + + if self.device_ids: + inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0]) + output = self.module.render(*inputs[0], **kwargs[0]) + else: + output = self.module.render(*inputs, **kwargs) + + if torch.is_grad_enabled() and self.require_backward_grad_sync: + self.require_forward_param_sync = True + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if self.find_unused_parameters and not self.static_graph: + # Do not need to populate this for static graph. + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + self.require_forward_param_sync = False + + # TODO: DDPSink is currently enabled for unused parameter detection and + # static graph training for first iteration. 
+ if (self.find_unused_parameters and not self.static_graph) or ( + self.static_graph and self.num_iterations == 1 + ): + state_dict = { + 'static_graph': self.static_graph, + 'num_iterations': self.num_iterations, + } + + output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref( + output + ) + output_placeholders = [None for _ in range(len(output_tensor_list))] + # Do not touch tensors that have no grad_fn, which can cause issues + # such as https://github.com/pytorch/pytorch/issues/60733 + for i, output in enumerate(output_tensor_list): + if torch.is_tensor(output) and output.grad_fn is None: + output_placeholders[i] = output + + # When find_unused_parameters=True, makes tensors which require grad + # run through the DDPSink backward pass. When not all outputs are + # used in loss, this makes those corresponding tensors receive + # undefined gradient which the reducer then handles to ensure + # param.grad field is not touched and we don't error out. + passthrough_tensor_list = _DDPSink.apply( + self.reducer, + state_dict, + *output_tensor_list, + ) + for i in range(len(output_placeholders)): + if output_placeholders[i] is None: + output_placeholders[i] = passthrough_tensor_list[i] + + # Reconstruct output data structure. + output = _tree_unflatten_with_rref( + output_placeholders, treespec, output_is_rref + ) + return output + + def custom_meshgrid(*args): # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid @@ -261,9 +368,10 @@ class LPIPSMeter: def update(self, preds, truths): preds, truths = self.prepare_inputs(preds, truths) # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1] - v = self.fn(truths, preds, normalize=True).item() # normalize=True: [0, 1] to [-1, 1] - self.V += v - self.N += 1 + v = self.fn(truths, preds, normalize=True) # normalize=True: [0, 1] to [-1, 1] + n = v.shape[0] + self.V += v.sum().item() + self.N += n def measure(self): return self.V / self.N @@ -322,9 +430,11 @@ class Trainer(object): self.console = Console() model.to(self.device) + self._model = model if self.world_size > 1: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) + # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) + model = DDP(model, device_ids=[local_rank]) self.model = model if isinstance(criterion, nn.Module): @@ -339,7 +449,10 @@ class Trainer(object): if optimizer is None: self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam else: - self.optimizer = optimizer(self.model) + if self.world_size > 1: + self.optimizer = optimizer(self.model.module) + else: + self.optimizer = optimizer(self.model) if lr_scheduler is None: self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler @@ -470,7 +583,6 @@ class Trainer(object): # outputs = self.model.render(rays_o, rays_d, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False if self.opt.patch_size == 1 else True, **vars(self.opt)) outputs = self.model.render(rays_o, rays_d, staged=False, bg_color=bg_color, perturb=True, force_all_rays=True, **vars(self.opt)) - pred_rgb = outputs['image'] # MSE loss @@ -521,11 +633,10 @@ class Trainer(object): # pred_weights_sum = outputs['weights_sum'] + 1e-8 # loss_ws = - 1e-1 * pred_weights_sum * torch.log(pred_weights_sum) # entropy to encourage weights_sum to be 0 or 1. 
# loss = loss + loss_ws.mean() - + return pred_rgb, gt_rgb, loss def eval_step(self, data): - rays_o = data['rays_o'] # [B, N, 3] rays_d = data['rays_d'] # [B, N, 3] images = data['images'] # [B, H, W, 3/4] @@ -597,8 +708,8 @@ class Trainer(object): self.writer = tensorboardX.SummaryWriter(os.path.join(self.workspace, "run", self.name)) # mark untrained region (i.e., not covered by any camera from the training dataset) - if self.model.cuda_ray: - self.model.mark_untrained_grid(train_loader._data.poses, train_loader._data.intrinsics) + if self._model.cuda_ray: + self._model.mark_untrained_grid(train_loader._data.poses, train_loader._data.intrinsics) # get a ref to error_map self.error_map = train_loader._data.error_map @@ -698,7 +809,7 @@ class Trainer(object): data = next(loader) # update grid every 16 steps - if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0: + if self._model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0: with torch.cuda.amp.autocast(enabled=self.fp16): self.model.update_extra_state() @@ -812,7 +923,7 @@ class Trainer(object): for data in loader: # update grid every 16 steps - if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0: + if self._model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0: with torch.cuda.amp.autocast(enabled=self.fp16): self.model.update_extra_state() @@ -984,9 +1095,9 @@ class Trainer(object): 'stats': self.stats, } - if self.model.cuda_ray: - state['mean_count'] = self.model.mean_count - state['mean_density'] = self.model.mean_density + if self._model.cuda_ray: + state['mean_count'] = self._model.mean_count + state['mean_density'] = self._model.mean_density if full: state['optimizer'] = self.optimizer.state_dict() @@ -1062,11 +1173,11 @@ class Trainer(object): if self.ema is not None and 'ema' in checkpoint_dict: self.ema.load_state_dict(checkpoint_dict['ema']) - if self.model.cuda_ray: + if self._model.cuda_ray: if 'mean_count' in checkpoint_dict: - self.model.mean_count = checkpoint_dict['mean_count'] + self._model.mean_count = checkpoint_dict['mean_count'] if 'mean_density' in checkpoint_dict: - self.model.mean_density = checkpoint_dict['mean_density'] + self._model.mean_density = checkpoint_dict['mean_density'] if model_only: return diff --git a/3d-reconstruction/hashnerf/readme.md b/3d-reconstruction/hashnerf/readme.md index 2721aaf8f..0bdf50f04 100644 --- a/3d-reconstruction/hashnerf/readme.md +++ b/3d-reconstruction/hashnerf/readme.md @@ -15,6 +15,9 @@ pip install -r requirements.txt We use the same data format as instant-ngp, e.g., [armadillo](https://github.com/NVlabs/instant-ngp/blob/master/data/sdf/armadillo.obj) and [fox](https://github.com/NVlabs/instant-ngp/tree/master/data/nerf/fox). Please download and put them under `./data`. +#bash +bash train.sh + We also support self-captured dataset and converting other formats (e.g., LLFF, Tanks&Temples, Mip-NeRF 360) to the nerf-compatible format, with details in the following code block. Supported datasets @@ -31,7 +34,6 @@ First time running will take some time to compile the CUDA extensions. # for the colmap dataset, the default dataset setting `--bound 2 --scale 0.33` is used. 
python main_nerf.py data/fox --workspace trial_nerf # fp32 mode python main_nerf.py data/fox --workspace trial_nerf --fp16 # fp16 mode (pytorch amp) -python main_nerf.py data/fox --workspace trial_nerf --fp16 --ff # fp16 mode + FFMLP (this repo's implementation) # one for all: -O means --fp16 --cuda_ray --preload, which usually gives the best results balanced on speed & performance. @@ -54,16 +56,6 @@ python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf -O --bound 1 python main_nerf.py data/nerf_llff_data/fern --workspace trial_nerf -O ``` -```bash -# for custom dataset, you should: -# 1. take a video / many photos from different views -# 2. put the video under a path like ./data/custom/video.mp4 or the images under ./data/custom/images/*.jpg. -# 3. call the preprocess code: (should install ffmpeg and colmap first! refer to the file for more options) -python scripts/colmap2nerf.py --video ./data/custom/video.mp4 --run_colmap # if use video -python scripts/colmap2nerf.py --images ./data/custom/images/ --run_colmap # if use images -# 4. it should create the transform.json, and you can train with: (you'll need to try with different scale & bound & dt_gamma to make the object correctly located in the bounding box and render fluently.) -python main_nerf.py data/custom --workspace trial_nerf_custom -O --gui --scale 2.0 --bound 1.0 --dt_gamma 0.02 -``` ## Results on BI-V100 diff --git a/3d-reconstruction/hashnerf/train.sh b/3d-reconstruction/hashnerf/train.sh new file mode 100644 index 000000000..1117c2685 --- /dev/null +++ b/3d-reconstruction/hashnerf/train.sh @@ -0,0 +1,7 @@ +#1gpu +python3 -m torch.distributed.launch --nproc_per_node=1 --use_env \ + main_nerf.py data/fox --workspace trial_nerf --num_gpus=1 +#8gpus +#CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \ +# main_nerf.py data/fox --workspace trial_nerf --num_gpus=8 + -- Gitee
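The `init_distributed_mode` hunk in `main_nerf.py` brings up the process group from the environment variables exported by the launcher and silences `print` on non-master ranks. A minimal standalone sketch of the same `env://` pattern (not taken from the patch; it assumes `RANK`, `WORLD_SIZE` and `LOCAL_RANK` are set by `torchrun` or `torch.distributed.launch --use_env`):

```python
# Sketch only: mirrors the patch's init_distributed_mode().
import os
import torch
import torch.distributed as dist

def init_from_env():
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if world_size > 1:
        dist.init_process_group(backend="nccl", init_method="env://",
                                rank=rank, world_size=world_size)
        torch.cuda.set_device(local_rank)  # one process per GPU
        dist.barrier()                     # wait until every rank is initialized
    return rank, world_size, local_rank
```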
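The `provider.py` change replaces the shuffled single-process `DataLoader` with one driven by `DistributedSampler` whenever `--num_gpus > 1`, so each rank draws a disjoint slice of the pose indices. An illustrative sketch with hypothetical sizes (`num_replicas` and `rank` are passed explicitly here so the snippet runs without an initialized process group; the patch lets them default to the live group):

```python
# Illustration only: how DistributedSampler partitions the pose indices.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

indices = list(range(10))      # stands in for list(range(size)) in dataloader()
for rank in range(2):          # pretend world_size == 2
    sampler = DistributedSampler(indices, num_replicas=2, rank=rank, shuffle=True)
    sampler.set_epoch(0)       # reseeds the shuffle; normally called once per epoch
    loader = DataLoader(indices, batch_size=1, sampler=sampler, num_workers=0)
    print(rank, [int(i) for i in loader])   # disjoint halves of the index list
```

`DistributedSampler` shuffles by default but only reshuffles when `set_epoch` is called with a new value each epoch; the hunk above does not appear to call it, so each rank replays the same sampling order every epoch.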
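The `DDP` subclass in `nerf/utils.py` copies the body of `DistributedDataParallel.forward` into a `render` method so that the trainer's `self.model.render(...)` calls still pass through DDP's gradient-bucketing machinery. Because the copy reaches into private internals (`self.reducer`, `self._join_config`, `_DDPSink`), it is tied to the PyTorch version it was taken from, and `_find_tensors`, `_tree_flatten_with_rref`, `_tree_unflatten_with_rref` and `_DDPSink` are referenced without being imported, so the branches guarded by `find_unused_parameters` / `static_graph` would fail with a `NameError` if they were ever enabled. A smaller, version-independent alternative (a sketch of a different approach, not what the patch does) is to route the custom method through the regular `forward`:

```python
# Alternative sketch: let stock DistributedDataParallel synchronize gradients
# through its normal forward path by delegating forward() to render().
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

class RenderAsForward(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        return self.model.render(*args, **kwargs)

# ddp_model = DistributedDataParallel(RenderAsForward(model), device_ids=[local_rank])
# outputs = ddp_model(rays_o, rays_d, staged=False, perturb=True)  # dispatches to model.render
# ddp_model.module.model.update_extra_state()  # reach the bare model for cuda_ray bookkeeping
```

Either way the bare model must stay reachable for the `cuda_ray` calls (`mark_untrained_grid`, `update_extra_state`, `mean_count`, `mean_density`), which is what the new `self._model` reference in `Trainer.__init__` provides.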
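`train.sh` launches through `python3 -m torch.distributed.launch --use_env`, which spawns one process per GPU and exports the `RANK` / `WORLD_SIZE` / `LOCAL_RANK` variables that `init_distributed_mode` reads. On PyTorch 1.10+ the same launch can be written with `torchrun` (equivalent command, not part of the patch):

```bash
# 8-GPU launch via torchrun; mirrors the commented 8-GPU line in train.sh
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 \
    main_nerf.py data/fox --workspace trial_nerf --num_gpus=8
```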