# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Log to `Weights and Biases <https://wandb.ai/>`_."""

from __future__ import annotations

import atexit
import copy
import os
import pathlib
import re
import sys
import tempfile
import textwrap
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

import numpy as np
import torch

from composer.loggers.logger import Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.utils import MissingConditionalImportError, dist

if TYPE_CHECKING:
    from composer.core import State

__all__ = ['WandBLogger']
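
# Usage sketch (illustrative only; the project/entity names and the Trainer
# wiring below are placeholders, not part of this module):
#
#     from composer import Trainer
#     from composer.loggers import WandBLogger
#
#     wandb_logger = WandBLogger(project='my-project', entity='my-team')
#     trainer = Trainer(model=model, train_dataloader=train_dl, loggers=[wandb_logger])
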
class WandBLogger(LoggerDestination):
    """Log to `Weights and Biases <https://wandb.ai/>`_.

    Args:
        project (str, optional): WandB project name.
        group (str, optional): WandB group name.
        name (str, optional): WandB run name.
            If not specified, the :attr:`.State.run_name` will be used.
        entity (str, optional): WandB entity name.
        tags (List[str], optional): WandB tags.
        log_artifacts (bool, optional): Whether to log
            `artifacts <https://docs.wandb.ai/ref/python/artifact>`_ (default: ``False``).
        rank_zero_only (bool, optional): Whether to log only on the rank-zero process.
            When logging `artifacts <https://docs.wandb.ai/ref/python/artifact>`_, it is
            highly recommended to log on all ranks. Artifacts from ranks >= 1 will not be
            stored, which may discard pertinent information. For example, when using
            DeepSpeed ZeRO, it would be impossible to restore from checkpoints without
            artifacts from all ranks (default: ``True``).
        init_kwargs (Dict[str, Any], optional): Any additional init kwargs to pass to
            ``wandb.init`` (see `WandB documentation <https://docs.wandb.ai/ref/python/init>`_).
    """

    def __init__(
        self,
        project: Optional[str] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        entity: Optional[str] = None,
        tags: Optional[List[str]] = None,
        log_artifacts: bool = False,
        rank_zero_only: bool = True,
        init_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        try:
            import wandb
        except ImportError as e:
            raise MissingConditionalImportError(extra_deps_group='wandb',
                                                conda_package='wandb',
                                                conda_channel='conda-forge') from e

        del wandb  # unused

        if log_artifacts and rank_zero_only and dist.get_world_size() > 1:
            warnings.warn(
                ('When logging artifacts, `rank_zero_only` should be set to False. '
                 'Artifacts from other ranks will not be collected, leading to a loss of information required to '
                 'restore from checkpoints.'))
        self._enabled = (not rank_zero_only) or dist.get_global_rank() == 0

        if init_kwargs is None:
            init_kwargs = {}

        if project is not None:
            init_kwargs['project'] = project

        if group is not None:
            init_kwargs['group'] = group

        if name is not None:
            init_kwargs['name'] = name

        if entity is not None:
            init_kwargs['entity'] = entity

        if tags is not None:
            init_kwargs['tags'] = tags

        self._rank_zero_only = rank_zero_only
        self._log_artifacts = log_artifacts
        self._init_kwargs = init_kwargs
        self._is_in_atexit = False

        # Set these variables directly to allow fetching an Artifact **without** initializing a WandB run.
        # When used as a LoggerDestination, these values are overridden from global rank 0 to all ranks on Event.INIT.
        self.entity = entity
        self.project = project

        self.run_dir: Optional[str] = None

    def _set_is_in_atexit(self):
        self._is_in_atexit = True

    def log_hyperparameters(self, hyperparameters: Dict[str, Any]):
        if self._enabled:
            import wandb
            wandb.config.update(hyperparameters)

    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
        if self._enabled:
            import wandb

            # wandb.log alters the metrics dictionary object, so we deepcopy to avoid
            # side effects.
            metrics_copy = copy.deepcopy(metrics)
            wandb.log(metrics_copy, step)

    def log_images(
        self,
        images: Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]],
        name: str = 'Images',
        channels_last: bool = False,
        step: Optional[int] = None,
        masks: Optional[Dict[str, Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray,
                                                                                 torch.Tensor]]]]] = None,
        mask_class_labels: Optional[Dict[int, str]] = None,
        use_table: bool = False,
    ):
        if self._enabled:
            import wandb
            if not isinstance(images, Sequence) and images.ndim <= 3:
                images = [images]

            # _convert_to_wandb_image doesn't include wrapping with wandb.Image to future
            # proof for when we support masks.
            images_generator = (_convert_to_wandb_image(image, channels_last) for image in images)

            if masks is not None:
                # Create a generator that yields masks in the format wandb wants.
                wandb_masks_generator = _create_wandb_masks_generator(masks,
                                                                      mask_class_labels,
                                                                      channels_last=channels_last)
                wandb_images = (wandb.Image(im, masks=mask_dict)
                                for im, mask_dict in zip(images_generator, wandb_masks_generator))
            else:
                wandb_images = (wandb.Image(image) for image in images_generator)

            if use_table:
                table = wandb.Table(columns=[name])
                for wandb_image in wandb_images:
                    table.add_data(wandb_image)
                wandb.log({name + ' Table': table}, step=step)
            else:
                wandb.log({name: list(wandb_images)}, step=step)

    def state_dict(self) -> Dict[str, Any]:
        import wandb

        # Storing these fields in the state dict to support run resuming in the future.
        if self._enabled:
            if wandb.run is None:
                raise ValueError('wandb module must be initialized before serialization.')

            # If WandB is disabled, most things are RunDisabled objects, which are not
            # pickleable due to overriding __getstate__ but not __setstate__.
            if wandb.run.disabled:
                return {}
            else:
                return {
                    'name': wandb.run.name,
                    'project': wandb.run.project,
                    'entity': wandb.run.entity,
                    'id': wandb.run.id,
                    'group': wandb.run.group,
                }
        else:
            return {}

    def init(self, state: State, logger: Logger) -> None:
        import wandb
        del logger  # unused

        # Use the state run name if the name is not set.
        if 'name' not in self._init_kwargs or self._init_kwargs['name'] is None:
            self._init_kwargs['name'] = state.run_name

        # Adjust name and group based on `rank_zero_only`.
        if not self._rank_zero_only:
            name = self._init_kwargs['name']
            self._init_kwargs['name'] += f'-rank{dist.get_global_rank()}'
            self._init_kwargs['group'] = self._init_kwargs['group'] if 'group' in self._init_kwargs else name
        if self._enabled:
            wandb.init(**self._init_kwargs)
            assert wandb.run is not None, 'The wandb run is set after init'
            entity_and_project = [str(wandb.run.entity), str(wandb.run.project)]
            self.run_dir = wandb.run.dir
            atexit.register(self._set_is_in_atexit)
        else:
            entity_and_project = [None, None]
        # Share the entity and project across all ranks, so they are available on ranks that did not initialize wandb.
        dist.broadcast_object_list(entity_and_project)
        self.entity, self.project = entity_and_project
        assert self.entity is not None, 'entity should be defined'
        assert self.project is not None, 'project should be defined'

    def upload_file(self, state: State, remote_file_name: str, file_path: pathlib.Path, *, overwrite: bool):
        del overwrite  # unused

        if self._enabled and self._log_artifacts:
            import wandb

            # Some WandB-specific alias extraction
            timestamp = state.timestamp
            aliases = ['latest', f'ep{int(timestamp.epoch)}-ba{int(timestamp.batch)}']

            # Replace all unsupported characters with periods.
            # Only alphanumeric characters, periods, hyphens, and underscores are supported by wandb.
            new_remote_file_name = re.sub(r'[^a-zA-Z0-9-_\.]', '.', remote_file_name)
            if new_remote_file_name != remote_file_name:
                warnings.warn(
                    ('WandB permits only alphanumeric characters, periods, hyphens, and underscores in file names. '
                     f"The file with name '{remote_file_name}' will be stored as '{new_remote_file_name}'."))

            extension = new_remote_file_name.split('.')[-1]

            metadata = {f'timestamp/{k}': v for (k, v) in state.timestamp.state_dict().items()}
            # If evaluating, also log the evaluation timestamp.
            if state.dataloader is not state.train_dataloader:
                # TODO If not actively training, then it is impossible to tell from the state whether
                # the trainer is evaluating or predicting. Assuming evaluation in this case.
                metadata.update({f'eval_timestamp/{k}': v for (k, v) in state.eval_timestamp.state_dict().items()})

            wandb_artifact = wandb.Artifact(
                name=new_remote_file_name,
                type=extension,
                metadata=metadata,
            )
            wandb_artifact.add_file(os.path.abspath(file_path))
            wandb.log_artifact(wandb_artifact, aliases=aliases)
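
    # Name-sanitization sketch: `upload_file` rewrites, e.g.,
    # 'ep1-ba100/rank0.pt' to 'ep1-ba100.rank0.pt', since WandB artifact names
    # permit only alphanumerics, periods, hyphens, and underscores:
    #
    #     >>> re.sub(r'[^a-zA-Z0-9-_\.]', '.', 'ep1-ba100/rank0.pt')
    #     'ep1-ba100.rank0.pt'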
    def can_upload_files(self) -> bool:
        """Whether the logger supports uploading files."""
        return True
    def download_file(
        self,
        remote_file_name: str,
        destination: str,
        overwrite: bool = False,
        progress_bar: bool = True,
    ):
        # Note: WandB doesn't support progress bars for downloading.
        del progress_bar
        import wandb
        import wandb.errors

        # Use wandb.Api() to support retrieving artifacts on ranks where
        # artifacts are not initialized.
        api = wandb.Api()
        if not self.entity or not self.project:
            raise RuntimeError('get_file_artifact can only be called after running init()')

        # Default to the latest alias if none was given.
        if ':' not in remote_file_name:
            remote_file_name += ':latest'

        # Replace all unsupported characters with periods.
        # Only alphanumeric characters, periods, hyphens, and underscores are supported by wandb.
        new_remote_file_name = re.sub(r'[^a-zA-Z0-9-_\.:]', '.', remote_file_name)
        if new_remote_file_name != remote_file_name:
            warnings.warn(
                ('WandB permits only alphanumeric characters, periods, hyphens, and underscores in file names. '
                 f"The file with name '{remote_file_name}' will be stored as '{new_remote_file_name}'."))

        try:
            wandb_artifact = api.artifact('/'.join([self.entity, self.project, new_remote_file_name]))
        except wandb.errors.CommError as e:
            if 'does not contain artifact' in str(e):
                raise FileNotFoundError(f'WandB Artifact {new_remote_file_name} not found') from e
            raise e
        with tempfile.TemporaryDirectory() as tmpdir:
            wandb_artifact_folder = os.path.join(tmpdir, 'wandb_artifact_folder')
            wandb_artifact.download(root=wandb_artifact_folder)
            wandb_artifact_names = os.listdir(wandb_artifact_folder)
            # We only log one file per artifact.
            if len(wandb_artifact_names) > 1:
                raise RuntimeError(
                    'Found more than one file in WandB artifact. We assume the checkpoint is the only file in the WandB artifact.'
                )
            wandb_artifact_name = wandb_artifact_names[0]
            wandb_artifact_path = os.path.join(wandb_artifact_folder, wandb_artifact_name)
            if overwrite:
                os.replace(wandb_artifact_path, destination)
            else:
                os.rename(wandb_artifact_path, destination)

    def post_close(self) -> None:
        import wandb

        # Clean up on post_close so all artifacts are uploaded.
        if not self._enabled or wandb.run is None or self._is_in_atexit:
            # Don't call wandb.finish if there is no run, or if
            # the script is in an atexit, since wandb also hooks into atexit
            # and it will error if wandb.finish is called from the Composer atexit hook
            # after it is called from the wandb atexit hook.
            return

        exc_tpe, exc_info, tb = sys.exc_info()
        if (exc_tpe, exc_info, tb) == (None, None, None):
            wandb.finish(0)
        else:
            # Record that there was an error.
            wandb.finish(1)
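
# Download sketch (illustrative only): fetching a previously uploaded
# checkpoint artifact. The project/entity/artifact names are placeholders and a
# valid WandB login is assumed; passing `project` and `entity` to the
# constructor lets `download_file` work without calling `init()` first.
def _example_download_checkpoint():  # pragma: no cover - illustrative only
    logger = WandBLogger(project='my-project', entity='my-team')
    # Without an explicit ':<alias>' suffix, download_file appends ':latest'.
    logger.download_file('ep1-ba100.rank0.pt', destination='/tmp/checkpoint.pt')
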
def _convert_to_wandb_image(image: Union[np.ndarray, torch.Tensor], channels_last: bool) -> np.ndarray:
    if isinstance(image, torch.Tensor):
        image = image.data.cpu().numpy()

    # Error out for empty arrays or weird arrays of dimension 0.
    if np.any(np.equal(image.shape, 0)):
        raise ValueError(f'Got an image (shape {image.shape}) with at least one dimension being 0! ')

    # Squeeze any singleton dimensions and then add them back in if the image
    # has fewer than 3 dimensions.
    image = image.squeeze()

    # Add in length-one dimensions to get back up to 3, putting channels last.
    if image.ndim == 1:
        image = np.expand_dims(image, (1, 2))
        channels_last = True
    if image.ndim == 2:
        image = np.expand_dims(image, 2)
        channels_last = True

    if image.ndim != 3:
        raise ValueError(
            textwrap.dedent(f'''Input image must be 3 dimensions, but instead
                            got {image.ndim} dims at shape: {image.shape}
                            Your input image was interpreted as a batch of {image.ndim}-dimensional images
                            because you either specified a {image.ndim + 1}D image or a list of {image.ndim}D images.
                            Please specify either a 4D image or a list of 3D images'''))
    assert isinstance(image, np.ndarray)
    if not channels_last:
        image = image.transpose(1, 2, 0)
    return image


def _convert_to_wandb_mask(mask: Union[np.ndarray, torch.Tensor], channels_last: bool) -> np.ndarray:
    mask = _convert_to_wandb_image(mask, channels_last)
    mask = mask.squeeze()
    if mask.ndim != 2:
        raise ValueError(f'Mask must be a 2D array, but instead got array of shape: {mask.shape}')
    return mask


def _preprocess_mask_data(masks: Dict[str, Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray,
                                                                                           torch.Tensor]]]],
                          channels_last: bool) -> Dict[str, np.ndarray]:
    preprocessed_masks = {}
    for mask_name, mask_data in masks.items():
        if not isinstance(mask_data, Sequence):
            mask_data = mask_data.squeeze()
            if mask_data.ndim == 2:
                mask_data = [mask_data]
        preprocessed_masks[mask_name] = np.stack([_convert_to_wandb_mask(mask, channels_last) for mask in mask_data])
    return preprocessed_masks


def _create_wandb_masks_generator(masks: Dict[str, Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray,
                                                                                                   torch.Tensor]]]],
                                  mask_class_labels: Optional[Dict[int, str]], channels_last: bool):
    preprocessed_masks: Dict[str, np.ndarray] = _preprocess_mask_data(masks, channels_last)

    for all_masks_for_single_example in zip(*list(preprocessed_masks.values())):
        mask_dict = {name: {'mask_data': mask} for name, mask in zip(masks.keys(), all_masks_for_single_example)}
        if mask_class_labels is not None:
            for k in mask_dict.keys():
                mask_dict[k].update({'class_labels': mask_class_labels})
        yield mask_dict
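
# Mask-logging sketch (illustrative only): the shapes `log_images` expects when
# logging segmentation masks. Array contents and class names are placeholders,
# and an active wandb run (normally started on Event.INIT) is assumed.
def _example_log_images_with_masks():  # pragma: no cover - illustrative only
    logger = WandBLogger(project='segmentation-demo')
    images = np.random.rand(2, 3, 32, 32)  # batch of two CHW images
    masks = {'predictions': np.random.randint(0, 3, size=(2, 32, 32))}
    logger.log_images(
        images,
        name='val_images',
        masks=masks,
        mask_class_labels={0: 'background', 1: 'road', 2: 'car'},
    )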