token-gallery-contracts/contracts/LibDiamond.sol

149 lines
7.2 KiB
Solidity

// SPDX-License-Identifier: MIT
pragma solidity ^0.7.2;
pragma experimental ABIEncoderV2;
/******************************************************************************\
* Author: Nick Mudge
*
* Library of internal functions for diamond storage access and diamond cuts.
/******************************************************************************/
import "./IDiamondCut.sol";
/// @title LibDiamond
/// @notice Internal helpers that keep the selector => facet routing table of a
///         diamond in a fixed storage slot and apply add/replace/remove cuts to it.
library LibDiamond {
    /// @dev Storage slot at which the `DiamondStorage` struct is anchored.
    bytes32 constant DIAMOND_STORAGE_POSITION = keccak256("diamond.standard.diamond.storage");

    /// @dev Facet implementing a selector, plus that selector's index in
    ///      `DiamondStorage.selectors`.
    struct FacetAddressAndSelectorPosition {
        address facetAddress;
        uint16 selectorPosition;
    }

    struct DiamondStorage {
        // function selector => facet address and selector position in selectors array
        mapping(bytes4 => FacetAddressAndSelectorPosition) facetAddressAndSelectorPosition;
        // every selector currently registered; order is not stable across removals
        bytes4[] selectors;
        // bytes4 id => flag; not read or written by this library
        // (presumably ERC-165 interface ids - confirm against the facets that use it)
        mapping(bytes4 => bool) supportedInterfaces;
        // owner of the contract
        address contractOwner;
    }

    /// @notice Returns a storage pointer to the diamond's bookkeeping struct.
    /// @return ds Reference to the struct located at `DIAMOND_STORAGE_POSITION`.
    function diamondStorage() internal pure returns (DiamondStorage storage ds) {
        bytes32 position = DIAMOND_STORAGE_POSITION;
        assembly {
            ds.slot := position
        }
    }

    /// @notice Emitted by `diamondCut` after all selector changes were applied and
    ///         before the optional initialisation call is made.
    event DiamondCut(IDiamondCut.FacetCut[] diamondCut, address init, bytes _calldata);

    /// @notice Internal function version of diamondCut.
    /// @dev This code is almost the same as the external diamondCut, except it takes
    ///      'FacetCut[] memory _diamondCut' instead of 'FacetCut[] calldata _diamondCut'.
    ///      The code is duplicated to prevent copying calldata to memory, which
    ///      causes an error for a two dimensional array.
    /// @param _diamondCut Facet addresses, actions and selectors to apply, in order.
    /// @param _init Address to delegatecall after the cut, or address(0) for none.
    /// @param _calldata Calldata for the delegatecall; must be empty iff `_init` is address(0).
    function diamondCut(
        IDiamondCut.FacetCut[] memory _diamondCut,
        address _init,
        bytes memory _calldata
    ) internal {
        // The running count is threaded through each cut so the helper does not
        // have to re-read the array length from storage.
        uint256 selectorCount = diamondStorage().selectors.length;
        for (uint256 facetIndex; facetIndex < _diamondCut.length; facetIndex++) {
            selectorCount = addReplaceRemoveFacetSelectors(
                selectorCount,
                _diamondCut[facetIndex].facetAddress,
                _diamondCut[facetIndex].action,
                _diamondCut[facetIndex].functionSelectors
            );
        }
        emit DiamondCut(_diamondCut, _init, _calldata);
        initializeDiamondCut(_init, _calldata);
    }

    /// @notice Applies a single facet cut (add, replace or remove) to the routing table.
    /// @param _selectorCount Number of selectors registered before this cut.
    /// @param _newFacetAddress Facet to route to; must be address(0) for Remove and
    ///        a contract address for Add/Replace.
    /// @param _action Which of Add / Replace / Remove to perform.
    /// @param _selectors Selectors affected by the cut; must not be empty.
    /// @return Number of selectors registered after this cut.
    function addReplaceRemoveFacetSelectors(
        uint256 _selectorCount,
        address _newFacetAddress,
        IDiamondCut.FacetCutAction _action,
        bytes4[] memory _selectors
    ) internal returns (uint256) {
        DiamondStorage storage ds = diamondStorage();
        require(_selectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
        if (_action == IDiamondCut.FacetCutAction.Add) {
            require(_newFacetAddress != address(0), "LibDiamondCut: Add facet can't be address(0)");
            enforceHasContractCode(_newFacetAddress, "LibDiamondCut: Add facet has no code");
            for (uint256 selectorIndex; selectorIndex < _selectors.length; selectorIndex++) {
                bytes4 selector = _selectors[selectorIndex];
                address oldFacetAddress = ds.facetAddressAndSelectorPosition[selector].facetAddress;
                require(oldFacetAddress == address(0), "LibDiamondCut: Can't add function that already exists");
                // selectorPosition is stored as uint16; refuse to truncate silently,
                // which would corrupt the index used by the Remove branch.
                require(_selectorCount <= type(uint16).max, "LibDiamondCut: Too many selectors");
                ds.facetAddressAndSelectorPosition[selector] = FacetAddressAndSelectorPosition(
                    _newFacetAddress,
                    uint16(_selectorCount)
                );
                ds.selectors.push(selector);
                _selectorCount++;
            }
        } else if (_action == IDiamondCut.FacetCutAction.Replace) {
            require(_newFacetAddress != address(0), "LibDiamondCut: Replace facet can't be address(0)");
            enforceHasContractCode(_newFacetAddress, "LibDiamondCut: Replace facet has no code");
            for (uint256 selectorIndex; selectorIndex < _selectors.length; selectorIndex++) {
                bytes4 selector = _selectors[selectorIndex];
                address oldFacetAddress = ds.facetAddressAndSelectorPosition[selector].facetAddress;
                // only useful if immutable functions exist
                require(oldFacetAddress != address(this), "LibDiamondCut: Can't replace immutable function");
                require(oldFacetAddress != _newFacetAddress, "LibDiamondCut: Can't replace function with same function");
                require(oldFacetAddress != address(0), "LibDiamondCut: Can't replace function that doesn't exist");
                // replace old facet address; the selector keeps its array position
                ds.facetAddressAndSelectorPosition[selector].facetAddress = _newFacetAddress;
            }
        } else if (_action == IDiamondCut.FacetCutAction.Remove) {
            require(_newFacetAddress == address(0), "LibDiamondCut: Remove facet address must be address(0)");
            for (uint256 selectorIndex; selectorIndex < _selectors.length; selectorIndex++) {
                bytes4 selector = _selectors[selectorIndex];
                FacetAddressAndSelectorPosition memory oldFacetAddressAndSelectorPosition = ds.facetAddressAndSelectorPosition[selector];
                require(oldFacetAddressAndSelectorPosition.facetAddress != address(0), "LibDiamondCut: Can't remove function that doesn't exist");
                // only useful if immutable functions exist
                require(oldFacetAddressAndSelectorPosition.facetAddress != address(this), "LibDiamondCut: Can't remove immutable function.");
                // Swap-and-pop: move the last selector into the vacated position so
                // the array stays dense. The existence check above guarantees
                // _selectorCount >= 1, so the subtraction cannot underflow.
                uint256 lastSelectorPosition = _selectorCount - 1;
                if (oldFacetAddressAndSelectorPosition.selectorPosition != lastSelectorPosition) {
                    bytes4 lastSelector = ds.selectors[lastSelectorPosition];
                    ds.selectors[oldFacetAddressAndSelectorPosition.selectorPosition] = lastSelector;
                    ds.facetAddressAndSelectorPosition[lastSelector].selectorPosition = oldFacetAddressAndSelectorPosition.selectorPosition;
                }
                // delete last selector
                ds.selectors.pop();
                delete ds.facetAddressAndSelectorPosition[selector];
                _selectorCount--;
            }
        } else {
            revert("LibDiamondCut: Incorrect FacetCutAction");
        }
        return _selectorCount;
    }

    /// @notice Runs the optional initialisation delegatecall that accompanies a cut.
    /// @dev Reverts when exactly one of `_init` / `_calldata` is set. If the
    ///      delegatecall fails, its revert data is re-thrown unchanged; when it
    ///      failed without data a generic message is used instead.
    /// @param _init Contract to delegatecall, or address(0) to skip initialisation.
    /// @param _calldata Encoded function call to execute in the diamond's context.
    function initializeDiamondCut(address _init, bytes memory _calldata) internal {
        if (_init == address(0)) {
            require(_calldata.length == 0, "LibDiamondCut: _init is address(0) but_calldata is not empty");
        } else {
            require(_calldata.length > 0, "LibDiamondCut: _calldata is empty but _init is not address(0)");
            // The code-size check is skipped when the target is this contract itself.
            if (_init != address(this)) {
                enforceHasContractCode(_init, "LibDiamondCut: _init address has no code");
            }
            (bool success, bytes memory returndata) = _init.delegatecall(_calldata);
            if (!success) {
                if (returndata.length > 0) {
                    // Bubble up the error exactly as the callee produced it.
                    // `revert(string(returndata))` would wrap the already ABI-encoded
                    // payload inside a second Error(string), garbling the reason.
                    assembly {
                        revert(add(returndata, 32), mload(returndata))
                    }
                } else {
                    revert("LibDiamondCut: _init function reverted");
                }
            }
        }
    }

    /// @notice Reverts with `_errorMessage` unless `_contract` has deployed code.
    /// @param _contract Address whose code size is inspected.
    /// @param _errorMessage Revert reason used when the code size is zero.
    function enforceHasContractCode(address _contract, string memory _errorMessage) internal view {
        uint256 contractSize;
        assembly {
            contractSize := extcodesize(_contract)
        }
        require(contractSize > 0, _errorMessage);
    }
}